In the course, you learned how to do classification using Fashion MNIST, a dataset of images of clothing. There's another, similar dataset called MNIST, which contains images of handwritten digits, 0 through 9.

Write an MNIST classifier that trains to 99% accuracy or above without relying on a fixed number of epochs. That is, training should stop as soon as it reaches that level of accuracy.
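The stopping rule comes down to comparing each epoch's accuracy against a threshold and breaking out of the loop once the threshold is met. Below is a minimal plain-Python sketch. It assumes a hypothetical helper class `AccuracyThreshold`, which is not part of PyTorch or the course code, and it uses made-up per-epoch accuracies in place of real training.

```python
class AccuracyThreshold:
    """Hypothetical helper: signals when training accuracy reaches a target."""

    def __init__(self, target: float = 0.99) -> None:
        self.target = target
        # f"{0.99:.0%}" formats as "99%"
        self.message = f"Reached {target:.0%} accuracy so cancelling training!"

    def should_stop(self, acc: float) -> bool:
        # "99% or greater" means the comparison is >=, not >
        if acc >= self.target:
            print(self.message)
            return True
        return False


EPOCHS = 10
# Made-up accuracies standing in for the value computed at the end of each epoch
fake_accuracies = [0.93, 0.97, 0.985, 0.99, 0.995]

stopper = AccuracyThreshold(0.99)
epochs_run = 0
for epoch, acc in zip(range(EPOCHS), fake_accuracies):
    epochs_run = epoch + 1
    print(f"Epoch {epochs_run}/{EPOCHS} - acc: {acc}")
    if stopper.should_stop(acc):
        break
```

The `>=` comparison matters. An epoch that lands exactly on 0.99 satisfies "99% or greater", but a strict `>` would keep training.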

Some notes:

- It should succeed in fewer than 10 epochs, so it is okay to set the number of epochs to 10, but nothing larger.
- When it reaches 99% or greater, it should print the string "Reached 99% accuracy so cancelling training!"
- If you add any additional variables, make sure you use the same names as the ones used in the course.

I've started the code for you below. How would you finish it?

# Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms

# Constants definition

In [2]:
EPOCHS = 10
IMAGE_SIZE = 28
# Keras's default learning rate is 1e-3
LEARNING_RATE = 1e-3
# Keras's default batch size is 32
BATCH_SIZE = 32

# Providing the data

In [3]:
# ToTensor converts each 28x28 grayscale image to a 1x28x28 float tensor scaled to [0, 1]
transform = transforms.ToTensor()
train_dataset = datasets.MNIST('../data/MNIST', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data/MNIST', train=False, transform=transform)

# Shuffle the training data every epoch, as Keras's fit() does by default
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
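`len(train_loader)` is the number of batches, not the number of images. That is why the training loop divides `loss_sum` by it, which gives the mean of the per-batch mean losses, but divides `correct` by `len(train_dataset)`. With the default `drop_last=False`, the batch count is the dataset size divided by the batch size, rounded up. The following plain-Python check works through that arithmetic for MNIST's 60,000 training images and 10,000 test images. The helper `batch_layout` is hypothetical and exists only for illustration.

```python
BATCH_SIZE = 32

def batch_layout(num_samples: int, batch_size: int) -> tuple[int, int]:
    """Return (number of batches, size of the last batch) when drop_last=False."""
    num_batches = (num_samples + batch_size - 1) // batch_size  # integer ceiling
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch

train_batches, train_last = batch_layout(60_000, BATCH_SIZE)
test_batches, test_last = batch_layout(10_000, BATCH_SIZE)
print(train_batches, train_last)  # 1875 32
print(test_batches, test_last)    # 313 16
```

For the training set, all 1,875 batches hold exactly 32 images, so `loss_sum / len(train_loader)` equals the per-image mean loss. On a dataset that does not divide evenly, the smaller final batch would be weighted slightly more.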

# Define the Neural Network, Optimizer, and Loss

In [4]:
# nn.CrossEntropyLoss applies log-softmax internally and expects raw logits,
# so the model must not end with a Softmax layer
# https://stackoverflow.com/a/61438119
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(IMAGE_SIZE * IMAGE_SIZE, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
#     nn.Softmax(dim=-1)
)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()
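This example shows why the final Softmax stays commented out. For one sample, `nn.CrossEntropyLoss` computes `-log_softmax(logits)[label]`, which equals `logsumexp(logits) - logits[label]`. The plain-Python sketch below works without PyTorch and uses made-up logits for a 3-class example. It computes that value directly and then shows what happens if the logits go through softmax first. The loss then gets squashed into a narrow range: for K classes it can never drop below log(e + K - 1) - 1, which is about 1.46 when K = 10. As a result, even a confident correct prediction is barely rewarded.

```python
import math

def log_sum_exp(xs):
    # Numerically stable log(sum(exp(x)))
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def cross_entropy(logits, label):
    # Per-sample value of nn.CrossEntropyLoss: -log_softmax(logits)[label]
    return log_sum_exp(logits) - logits[label]

def softmax(xs):
    lse = log_sum_exp(xs)
    return [math.exp(x - lse) for x in xs]

# Made-up logits for a 3-class example; class 0 is the correct label
logits = [8.0, 1.0, 0.5]
label = 0

correct_loss = cross_entropy(logits, label)                  # raw logits, as intended
double_softmax_loss = cross_entropy(softmax(logits), label)  # Softmax layer left in

print(f"{correct_loss:.4f}")         # ≈ 0.0015
print(f"{double_softmax_loss:.4f}")  # ≈ 0.5524, close to its floor log(e + 2) - 1 ≈ 0.5514
```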

# Training the Neural Network

In [5]:
for epoch in range(EPOCHS):
    print(f'Epoch {epoch+1}/{EPOCHS}')
    loss_sum = 0
    correct = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        out = model(images)
        loss = loss_fn(out, labels)
        loss.backward()
        optimizer.step()
        
        # Predicted class = index of the largest logit
        predicted = out.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        loss_sum += loss.item()

    acc = correct / len(train_dataset)
    print(f'loss: {loss_sum / len(train_loader):e} - acc: {acc}')
    
    # The spec asks for 99% or greater, so use >=
    if acc >= 0.99:
        print('Reached 99% accuracy so cancelling training!')
        break

Epoch 1/10
loss: 2.315169e-01 - acc: 0.9324166666666667
Epoch 2/10
loss: 9.158986e-02 - acc: 0.97235
Epoch 3/10
loss: 5.555138e-02 - acc: 0.9837333333333333
Epoch 4/10
loss: 3.477582e-02 - acc: 0.9896333333333334
Epoch 5/10
loss: 2.373431e-02 - acc: 0.9923166666666666
Reached 99% accuracy so cancelling training!
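`torch.max(out.data, 1)` returns a `(values, indices)` pair. The indices along dimension 1 are the predicted class for each image, that is, the argmax of the logits, and `out.argmax(dim=1)` returns the same indices directly. The plain-Python sketch below shows how one batch contributes to `correct`. It uses made-up logits for a batch of three images and four classes, and a hypothetical `argmax` helper in place of the tensor method.

```python
def argmax(row):
    # Index of the largest logit; max() keeps the first occurrence on ties
    return max(range(len(row)), key=row.__getitem__)

# Made-up logits: 3 images x 4 classes
batch_logits = [
    [2.0, 0.1, -1.0, 0.3],   # largest logit at index 0
    [0.2, 0.5, 3.1, -0.4],   # largest logit at index 2
    [1.5, 1.7, 0.0, 0.9],    # largest logit at index 1
]
labels = [0, 2, 3]

predicted = [argmax(row) for row in batch_logits]
correct = sum(p == y for p, y in zip(predicted, labels))
print(predicted, correct)  # [0, 2, 1] 2
```

Summing `correct` over all batches and dividing by `len(train_dataset)` gives the epoch accuracy that is compared against the 0.99 threshold. This training accuracy is accumulated while the weights are still changing during the epoch, so it is not the same number that a separate pass over the training set would report.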


# Evaluating the result

In [6]:
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        out = model(images)
        predicted = out.argmax(dim=1)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy: {100 * correct / len(test_dataset)}%')

Test accuracy: 97.74%
