In [None]:
!pip3 install torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
# Log the installed PyTorch version so results can be matched to the library release.
print(torch.__version__)

# Load the data

In [None]:
# Download MNIST into ./data (not re-downloaded if already present) and wrap the
# train and test splits as torchvision Dataset objects. ToTensor() converts each
# image to a float tensor scaled to [0, 1]; the raw uint8 pixels stay available
# through each dataset's `.data` attribute.
transform = transforms.Compose([transforms.ToTensor()])

trainset, testset = (
    torchvision.datasets.MNIST(root='./data', download=True, train=is_train, transform=transform)
    for is_train in (True, False)
)

In [None]:
# How big is our dataset? What kind of data do we have? 

# One summary group per split (train first, then test), separated by a blank line:
# the shape and dtype of the images tensor, then of the labels tensor.
for position, dataset in enumerate((trainset, testset)):
  if position > 0:
    print()
  for field in (dataset.data, dataset.targets):
    print(field.shape, ', ', field.dtype)

In [None]:
# Look at an example: print the raw pixel grid of the first training image.
# Images are monochrome with integer pixel values between 0 and 255 (inclusive).
# The wide line width stops rows of the pixel grid from wrapping; note that this
# print option stays in effect for every later tensor print in the session.
torch.set_printoptions(linewidth=1000)
first_image = trainset.data[0]
print(first_image)

In [None]:
# Visualize some images and check their labels (shown as each panel's title).
import matplotlib.pyplot as plt
import numpy as np

num_examples = 6
# squeeze=False keeps `axes` 2-D, so indexing axes[0] works even for one example.
fig, axes = plt.subplots(1, num_examples, figsize=(2 * num_examples, 2.5), squeeze=False)
for ax, image, label in zip(axes[0], trainset.data[:num_examples], trainset.targets[:num_examples]):
  # Grayscale colormap: the images are single-channel intensities, which the
  # default colormap would render in false colour.
  ax.imshow(image.numpy(), cmap='gray')
  ax.set_title(f'label: {label.item()}')
  ax.axis('off')
plt.show()

# Define the model

In [None]:
# TODO: implement FF and CNN layers myself

In [None]:
# Reference: https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/01-basics/feedforward_neural_network/main.py#L37-L49

class FFNN(nn.Module):
  """Fully connected classifier: flatten -> [Linear -> ReLU] per hidden layer -> Linear.

  Args:
    input_size: number of features after flattening every dim except the batch dim.
    output_size: number of output logits (classes).
    hidden_sizes: widths of the hidden layers, in order. Defaults to (200, 100).
  """

  def __init__(self, input_size, output_size, hidden_sizes=(200, 100)):
    super().__init__()
    sizes = [input_size, *hidden_sizes]
    # nn.ModuleList instead of a plain Python list: modules kept in a plain list
    # are not registered, so their weights would be missing from
    # model.parameters() (never trained by the optimizer) and would not follow
    # model.to(device).
    self.layers = nn.ModuleList(
        [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])])
    self.output_layer = nn.Linear(sizes[-1], output_size)

  def forward(self, x):
    """Return logits of shape (batch, output_size) for the input batch `x`."""
    x = x.flatten(1)
    for layer in self.layers:
      x = F.relu(layer(x))
    return self.output_layer(x)

In [None]:
# Reference: https://www.kaggle.com/code/sdelecourt/cnn-with-pytorch-for-mnist/notebook

class CNN(nn.Module):
  """Small convolutional classifier for single-channel images.

  Architecture: two stages of Conv2d(kernel_size=5, no padding) -> 2x2 max-pool
  -> ReLU (1 -> 32 -> 64 channels), then flatten -> Linear(., 200) -> ReLU ->
  Linear(200, output_size) producing logits.

  Args:
    image_shape: 2-tuple with the spatial size of the input images, e.g. (28, 28).
      Only the product of the two reduced sizes is used, so their order does not
      affect the layer sizes.
    output_size: number of output logits (classes).
  """

  def __init__(self, image_shape, output_size):
    super().__init__()
    # nn.ModuleList instead of a plain Python list: modules kept in a plain list
    # are not registered, so their weights would be missing from
    # model.parameters() (never trained by the optimizer) and would not follow
    # model.to(device).
    self.cnn_layers = nn.ModuleList([
        nn.Conv2d(1, 32, kernel_size=5),
        nn.Conv2d(32, 64, kernel_size=5),
    ])
    w, h = image_shape
    flat_features = self._reduced_size(w) * self._reduced_size(h) * 64
    self.ff_layers = nn.ModuleList([nn.Linear(flat_features, 200)])
    self.output_layer = nn.Linear(200, output_size)

  @staticmethod
  def _reduced_size(size):
    """Spatial size after both conv stages: each does -4 (5x5 conv) then //2 (pool)."""
    for _ in range(2):
      size = (size - 4) // 2
    return size

  def forward(self, x):
    """Return logits of shape (batch, output_size).

    Expects `x` as (batch, 1, H, W) with a spatial size consistent with the
    `image_shape` the model was built for.
    """
    for cnn_layer in self.cnn_layers:
      x = F.relu(F.max_pool2d(cnn_layer(x), 2))
    x = x.flatten(1)
    for ff_layer in self.ff_layers:
      x = F.relu(ff_layer(x))
    return self.output_layer(x)

In [None]:
# TODO: also demonstrate Sequential
# https://jeancochrane.com/blog/pytorch-functional-api

# Training loop

In [None]:
# TODO: implement my own optimizer

In [None]:
learning_rate = 1e-3
batch_size = 32
seed = 0

# Seed torch's global RNG so weight initialisation and the DataLoader's shuffle
# order repeat across runs (some GPU kernels may still be nondeterministic).
torch.manual_seed(seed)

# Use GPU if available. Chosen before the model is built so its parameters are
# already on `device` when the optimizer is constructed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Feed-forward alternative: model = FFNN(28*28, 10).to(device)
# The training loop sends each batch to `device`, so the model must live there too.
model = CNN((28, 28), 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# DataLoader wraps our dataset and spits it out batch by batch.
# `shuffle=True` reshuffles the order of the examples at the start of every epoch.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)


def accuracy(logits, target):
  """Fraction of rows whose highest-scoring class equals `target`.

  Args:
    logits: tensor of per-class scores, classes along dim 1.
    target: tensor of class indices, one per row of `logits`.
  Returns:
    0-dim float32 tensor: the mean of the per-row 0/1 correctness.
  """
  predicted = torch.argmax(logits, dim=1)
  hits = predicted == target
  return hits.float().mean()

In [None]:
num_epochs = 10
# End-of-epoch metrics: the "Train" numbers use only the first `num_train_eval`
# training images to keep evaluation cheap; the test split is used in full.
num_train_eval = 10000
# Evaluate in chunks so a whole split never goes through the model at once.
eval_batch_size = 1000

# Idempotent safety net: the batches below are sent to `device`, so the model
# must be there too.
model.to(device)


def evaluate(images, targets):
  """Return (mean loss, accuracy) of `model` on raw pixel `images` and `targets`.

  Assumes `images` is (N, H, W) uint8 like `trainset.data`. It gets a channel dim
  and the same [0, 1] scaling as the ToTensor() transform used during training.
  `targets` must hold one class index per image. Runs in eval mode without autograd.
  """
  model.eval()
  chunks = []
  with torch.no_grad():
    for start in range(0, len(images), eval_batch_size):
      batch = images[start:start + eval_batch_size, None]
      chunks.append(model(batch.to(device, dtype=torch.float32) / 255))
  logits = torch.cat(chunks)
  targets = targets.to(device)
  return criterion(logits, targets).item(), accuracy(logits, targets).item()


for epoch in range(num_epochs):
  model.train()
  for images, labels in train_loader:
    # Move tensors to the configured device
    images = images.to(device)
    labels = labels.to(device)

    # Forward pass
    outputs = model(images)
    loss = criterion(outputs, labels)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # Targets are sliced to the same rows as the images; comparing 10k logits with
  # all 60k training targets would be a shape mismatch.
  train_loss, train_acc = evaluate(trainset.data[:num_train_eval], trainset.targets[:num_train_eval])
  test_loss, test_acc = evaluate(testset.data, testset.targets)
  print('Epoch: %d | Train Loss: %.4f | Train Accuracy: %.2f | Test Accuracy: %.2f' % (epoch, train_loss, train_acc, test_acc))