In [1]:
from torchvision import datasets, transforms
import torch

In [2]:
# Fetch MNIST into ./data (skipped when the files are already there) and load
# the training split (is_train=True) followed by the test split; ToTensor
# turns every image into a PyTorch tensor.
train_dataset, test_dataset = (
  datasets.MNIST(root="./data", train=is_train, transform=transforms.ToTensor(), download=True)
  for is_train in (True, False)
)

In [3]:
batch_size = 100

# One DataLoader per split. Only the training split is shuffled so each epoch
# visits the batches in a new order; the test split keeps a fixed order.
train_loader, test_loader = (
  torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=reshuffle)
  for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

In [4]:
import torch.nn as nn

class MNISTModel(nn.Module):
  """Fully connected classifier for flattened 28x28 MNIST digits.

  Layers: 784 -> 256 -> 128 -> 64 -> 10, with ReLU after every hidden layer
  and dropout applied once, right after the first hidden layer. forward()
  returns raw, unnormalized class scores (logits).

  Args:
    dropout: probability of zeroing each activation of the first hidden
      layer while in train mode. Defaults to 0.3 (the original value).
  """

  def __init__(self, dropout: float = 0.3):
    super().__init__()
    # Layers are created in the same order as before so that, for a given
    # seed, weight initialisation is unchanged.
    self.linear1 = nn.Linear(784, 256)
    # Dropout randomly zeroes activations of the first hidden layer during
    # training so the network cannot lean on a few specific neurons (a guard
    # against overfitting). It is inactive after model.eval().
    self.dropout1 = nn.Dropout(dropout)
    self.linear2 = nn.Linear(256, 128)
    self.linear3 = nn.Linear(128, 64)
    # 64 inputs from the last hidden layer -> 10 output neurons, one logit
    # per digit class 0-9
    self.linear4 = nn.Linear(64, 10)

  def forward(self, x):
    """Map a batch of images to class logits.

    Args:
      x: image tensor whose elements group into rows of 784 pixels,
        e.g. (N, 1, 28, 28) batches of ToTensor'd MNIST images.

    Returns:
      Tensor of shape (N, 10) with one unnormalized score per class.
    """
    # Flatten each image to a 784-vector (-1 infers the batch size).
    # reshape, unlike view, also accepts non-contiguous inputs.
    x = x.reshape(-1, 784)
    x = torch.relu(self.linear1(x))  # activation function
    x = self.dropout1(x)
    x = torch.relu(self.linear2(x))
    x = torch.relu(self.linear3(x))
    return self.linear4(x)

# Seed PyTorch's global RNG before building the model so that weight
# initialisation, dropout masks and (because the DataLoaders are given no
# explicit generator) the training shuffle order repeat across
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = MNISTModel()

In [5]:
import torch.optim as optim
import torch.nn.functional as F

# Cross-entropy loss scores how well the model's raw output scores match the
# true integer labels; it applies log-softmax itself, which is why the
# model's forward pass returns unnormalized logits.
criterion = nn.CrossEntropyLoss()

# Adam updates the model's weights from each batch's gradients.
optimizer = optim.Adam(lr=0.006, params=model.parameters())

In [None]:
num_epochs = 6 # iterations over the entire dataset

# model.eval() (used for evaluation) disables dropout. Switch back to train
# mode explicitly so re-running this cell after evaluating still trains with
# dropout active.
model.train()
for epoch in range(num_epochs):
  for batch_idx, (data, targets) in enumerate(train_loader):
    # forward pass: the batch of images from train_loader is passed through
    # the model to get one row of class scores per image
    outputs = model(data)
    # loss = difference between predictions and true labels
    loss = criterion(outputs, targets)

    # reset gradients first: backward() accumulates into .grad, so without
    # this the previous batch's gradients would leak into this update
    optimizer.zero_grad()
    # backward pass: compute gradients of the loss w.r.t. every weight
    loss.backward()
    # update the weights using those gradients
    optimizer.step()

    # report the current batch loss every 100 batches
    if batch_idx % 100 == 0:
      print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, Loss: {loss.item():.4f}')

Epoch: 1, Batch: 1, Loss: 2.3144
Epoch: 1, Batch: 101, Loss: 0.3602
Epoch: 1, Batch: 201, Loss: 0.3078
Epoch: 1, Batch: 301, Loss: 0.4019
Epoch: 1, Batch: 401, Loss: 0.2588


In [None]:
model.eval()  # disable dropout for evaluation
test_loss = 0.0
correct = 0
# for eval, loop over the test set without tracking gradients
with torch.no_grad():
  for data, targets in test_loader:
    outputs = model(data)
    # criterion was built with the default reduction='mean', so it returns
    # the *average* loss of the batch. Multiply by the batch size to get the
    # batch's summed loss; dividing the total by the number of samples below
    # then gives the true per-sample mean (also correct for a short last batch).
    test_loss += criterion(outputs, targets).item() * targets.size(0)

    # outputs holds one row of 10 class scores per image; the index of the
    # largest score in each row is the predicted digit. Training pushes the
    # softmax probability of the true class up and the other 9 classes' down,
    # so the largest score marks the most likely class.
    predicted = outputs.argmax(dim=1)
    correct += (predicted == targets).sum().item()

num_test = len(test_loader.dataset)
test_loss /= num_test  # mean loss per test sample
acc = 100 * correct / num_test
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {acc:.2f}%')