In [1]:
import gzip #for zip file
import pickle  #Uses pickle to load the Python objects stored in mnist.
import numpy as np
import torch
from torch import nn, optim #nn - neural network [eg: nn.linear()], optim - optimizer [eg: optim.Adam()]
import torch.nn.functional as F  #Give acces to activation functions [eg: F.relu(), F.softmax() etc]
from torch.utils.data import TensorDataset, DataLoader  #TensorDataset: wrap input tensors and target tensors into dataset object
                                                        #DataLoader to load data batches, shuffle it and iterate over it during training

In [2]:
def load_data(path='mnist.pkl.gz'):
  """Load the pickled MNIST splits and wrap each one in a TensorDataset.

  Args:
    path: gzip-compressed pickle holding a (training, validation, testing)
      tuple, where each split is an (images, labels) pair. Defaults to
      'mnist.pkl.gz' in the current working directory.

  Returns:
    (training, validation, testing) TensorDatasets that yield (image, label)
    pairs; images are converted to float32 and labels to int64 (torch.long).

  Security:
    pickle.load can execute arbitrary code embedded in the file -- only load
    files obtained from a trusted source.
  """
  with gzip.open(path, 'rb') as file:
    # encoding='latin1' lets pickle read Python 2 pickles that contain NumPy
    # arrays (per the pickle docs); presumably this file is the classic
    # Python 2 MNIST pickle -- harmless for Python 3 pickles.
    training_set, validation_set, testing_set = pickle.load(file, encoding='latin1')

  def tensor_data(data):
    # Expected: data[0] images of shape (num_samples, 784) -- the network's
    # first weight has 784 rows -- and data[1] labels of shape (num_samples,).
    x = torch.tensor(data[0], dtype=torch.float32)
    y = torch.tensor(data[1], dtype=torch.long)  # cross_entropy needs int64 targets
    return TensorDataset(x, y)  # yields (image, label) pairs

  return tensor_data(training_set), tensor_data(validation_set), tensor_data(testing_set)

In [3]:
def initialize_parameters(layer_sizes=(784, 128, 64, 10), scale=0.01):
  """Create the weights and biases of a fully connected network.

  Args:
    layer_sizes: widths of consecutive layers, input first. The default
      (784, 128, 64, 10) builds the 784 -> 128 -> 64 -> 10 MNIST classifier.
    scale: multiplier applied to standard-normal weight samples; 0.01 keeps
      the initial weights (and hence the initial activations) small.

  Returns:
    Flat list [W_1, b_1, W_2, b_2, ...] of leaf tensors with
    requires_grad=True. W_i has shape (layer_sizes[i-1], layer_sizes[i]) and
    b_i is a zero vector of length layer_sizes[i].

  Raises:
    ValueError: if fewer than two layer sizes are given.
  """
  if len(layer_sizes) < 2:
    raise ValueError("layer_sizes needs at least an input and an output width")
  parameters = []
  for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    # randn(...) * scale is already a fresh leaf tensor that tracks no
    # gradients, so requires_grad_() can be applied directly; a
    # .clone().detach() round-trip would only add an extra copy.
    weight = (torch.randn(fan_in, fan_out) * scale).requires_grad_()
    bias = torch.zeros(fan_out, requires_grad=True)
    parameters.extend([weight, bias])
  return parameters

def forward(x, parameters):
  """Run the fully connected network described by ``parameters`` on a batch.

  Args:
    x: input batch whose last dimension equals the first weight's row count
      (784 for the default MNIST network).
    parameters: flat list [W_1, b_1, W_2, b_2, ...] as produced by
      initialize_parameters(); any number of (weight, bias) pairs is accepted.

  Returns:
    Raw logits of the last layer. No softmax is applied here, because
    F.cross_entropy applies log-softmax itself.

  Raises:
    ValueError: if ``parameters`` is empty or holds an odd number of tensors.
  """
  if not parameters or len(parameters) % 2:
    raise ValueError("parameters must be a non-empty list of (weight, bias) pairs")
  weights, biases = parameters[0::2], parameters[1::2]
  last = len(weights) - 1
  for i, (W, b) in enumerate(zip(weights, biases)):
    x = x @ W + b
    if i < last:
      x = F.relu(x)  # hidden layers only; the output layer stays linear
  return x

In [4]:
def accuracy(output, target):
  """Share of rows whose highest-scoring column matches ``target``.

  Returns a plain Python float between 0 and 1.
  """
  hits = torch.eq(torch.argmax(output, dim=1), target)
  return hits.float().mean().item()

def train(training_loader, validation_loader, parameters, epochs = 10, lr = 0.1):
  """Train the network with plain SGD on cross-entropy, validating each epoch.

  Args:
    training_loader: iterable of (x, y) mini-batches used for the updates.
    validation_loader: loader handed to evaluation_model() after every epoch.
    parameters: list of leaf tensors (requires_grad=True), updated in place.
    epochs: number of passes over training_loader.
    lr: SGD learning rate.

  Returns:
    List holding the mean per-batch training loss of each epoch (NaN for an
    epoch with no batches). Callers that ignored the return value are unaffected.
  """
  optimizer = optim.SGD(parameters, lr=lr)
  epoch_losses = []
  for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0
    for x, y in training_loader:
      # Forward pass and loss
      output = forward(x, parameters)
      loss = F.cross_entropy(output, y)

      # Backpropagation and update
      optimizer.zero_grad()  # clear gradients left over from the previous step
      loss.backward()        # compute new gradients
      optimizer.step()       # update the weights in place

      total_loss += loss.item()
      num_batches += 1

    # The summed loss grows with the number of batches, so also report the
    # per-batch mean, which is comparable across batch sizes and datasets.
    mean_loss = total_loss / num_batches if num_batches else float('nan')
    epoch_losses.append(mean_loss)
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}, Mean batch loss: {mean_loss:.4f}")
    evaluation_model(validation_loader, parameters, name="Validation")
  return epoch_losses

def evaluation_model(loader, parameters, name="Test"):
  """Print and return the network's classification accuracy on ``loader``.

  Args:
    loader: iterable of (x, y) batches, e.g. a DataLoader over a TensorDataset.
    parameters: network parameters passed through to forward().
    name: label used in the printed line.

  Returns:
    Accuracy as a percentage (0-100), or float('nan') when the loader yields
    no samples. Callers that ignored the return value are unaffected.
  """
  correct = 0
  total = 0
  with torch.no_grad():  # inference only: no autograd graph needed
    for x, y in loader:
      pred = forward(x, parameters).argmax(dim=1)
      correct += (pred == y).sum().item()
      total += y.size(0)
  if total == 0:
    # An empty loader previously crashed with ZeroDivisionError.
    print(f"{name} Accuracy: n/a (no samples)")
    return float('nan')
  # Named acc_percent so it does not shadow the module-level accuracy() helper.
  acc_percent = correct / total * 100
  print(f"{name} Accuracy: {acc_percent:.2f}%")
  return acc_percent

# Hyperparameters for this run. Seeding torch's global RNG makes both the
# weight initialisation (torch.randn) and the shuffling order of
# training_loader (DataLoader draws from the default generator) reproducible.
SEED = 0
BATCH_SIZE = 64
EPOCHS = 10
LEARNING_RATE = 0.1
torch.manual_seed(SEED)

training_data, validation_data, testing_data = load_data()
training_loader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_data, batch_size=BATCH_SIZE)
testing_loader = DataLoader(testing_data, batch_size=BATCH_SIZE)
parameters = initialize_parameters()
train(training_loader, validation_loader, parameters, epochs=EPOCHS, lr=LEARNING_RATE)
evaluation_model(testing_loader, parameters, name="Test")

Epoch 1, Loss: 1177.0869
Validation Accuracy: 83.46%
Epoch 2, Loss: 291.6430
Validation Accuracy: 92.98%
Epoch 3, Loss: 166.1100
Validation Accuracy: 92.92%
Epoch 4, Loss: 120.2431
Validation Accuracy: 96.05%
Epoch 5, Loss: 94.8618
Validation Accuracy: 96.54%
Epoch 6, Loss: 78.1741
Validation Accuracy: 96.66%
Epoch 7, Loss: 65.2059
Validation Accuracy: 96.90%
Epoch 8, Loss: 55.4272
Validation Accuracy: 97.13%
Epoch 9, Loss: 48.8235
Validation Accuracy: 97.55%
Epoch 10, Loss: 41.8411
Validation Accuracy: 96.78%
Test Accuracy: 96.77%
