# Pytorch MNIST example
#### https://github.com/pytorch/examples/blob/master/mnist/main.py
#### https://torchmetrics.readthedocs.io/en/latest/


In [79]:
!pip install torchmetrics



In [80]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchmetrics
import torchvision

In [81]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
dataset1 = torchvision.datasets.MNIST('../data', train=True, transform=transform, download=True)
dataset2 = torchvision.datasets.MNIST('../data', train=False, transform=transform)

In [82]:
print(dataset1)

Dataset MNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )


In [83]:
# Load the mnist data
train_kwargs = {'batch_size': 32}
test_kwargs = {'batch_size': 10}
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [84]:
# Model
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 40),
    nn.Linear(40, 30),
    nn.Linear(30, 20),
    nn.Linear(20, 10),
)

In [85]:
print(model)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=40, bias=True)
  (2): Linear(in_features=40, out_features=30, bias=True)
  (3): Linear(in_features=30, out_features=20, bias=True)
  (4): Linear(in_features=20, out_features=10, bias=True)
)


In [86]:
# Set loss function and the optimiser
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), 0.001)

In [87]:
def train(metric_train):
  model.train()
  for batch_idx, (images, labels) in enumerate(train_loader):
    # set the gradients to zero (flush accumilated gradients)
    optimizer.zero_grad()
    # Do a prediction step on a batch of images
    predictions = model(images)
    # Calculate the loss
    loss = loss_fn(predictions, labels)
    # Do a backward step to compute the model gradients
    loss.backward()
    # Update the parameters of the model
    optimizer.step()

    metric_train(predictions, labels)
  acc = metric_train.compute()
  
  return acc

In [88]:
# Validation step
def validate(metric_val):
  model.eval() # put the model into evaluation mode
  # Take the model without the gradients
  with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(test_loader):
      # do a prediction
      predictions = model(images)
      # compute accuracy
      metric_val(predictions, labels)

  acc = metric_val.compute()
  return acc

In [89]:
# initialize metric
metric_train = torchmetrics.Accuracy()
metric_val = torchmetrics.Accuracy()
# train the model
epochs = 10

for epoch in range(epochs):
  train_acc = train(metric_train)
  print(f"Train Accuracy on epoch {epoch}: {train_acc}")
  val_acc = validate(metric_val)
  print(f"Validation Accuracy on epoch {epoch}: {val_acc}")

Train Accuracy on epoch 0: 0.48401665687561035
Validation Accuracy on epoch 0: 0.6793000102043152
Train Accuracy on epoch 1: 0.6211666464805603
Validation Accuracy on epoch 1: 0.7458500266075134
Train Accuracy on epoch 2: 0.6917222142219543
Validation Accuracy on epoch 2: 0.7847333550453186
Train Accuracy on epoch 3: 0.7354999780654907
Validation Accuracy on epoch 3: 0.8092749714851379
Train Accuracy on epoch 4: 0.7650666832923889
Validation Accuracy on epoch 4: 0.8258799910545349
Train Accuracy on epoch 5: 0.7862389087677002
Validation Accuracy on epoch 5: 0.8378000259399414
Train Accuracy on epoch 6: 0.8021023869514465
Validation Accuracy on epoch 6: 0.8467857241630554
Train Accuracy on epoch 7: 0.8145020604133606
Validation Accuracy on epoch 7: 0.8539000153541565
Train Accuracy on epoch 8: 0.8244647979736328
Validation Accuracy on epoch 8: 0.8598111271858215
Train Accuracy on epoch 9: 0.8326166868209839
Validation Accuracy on epoch 9: 0.8647199869155884
