In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import time
import os

In [None]:
"""
Today we will train a neural network to classify images of handwritten digits from the MNIST dataset.
The MNIST dataset can be loaded directly through torchvision (PyTorch's vision library), as can many other datasets.

https://pytorch.org/vision/stable/datasets.html


What do we need to train a model?
- 
-
-
-
-
-
-

"""

In [None]:
# transforms: image -> float tensor, then normalize the single grayscale channel
mnist_transforms = transforms.Compose(
    [
        transforms.ToTensor(),                        # PIL/uint8 image -> float tensor scaled to [0, 1]
        transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5 -> roughly [-1, 1]
    ]
)

In [None]:
# load the MNIST train split and the held-out test split (used here for validation);
# download=True fetches the files into the root directory if they are missing
mnist_train_data = datasets.MNIST(root='/home/jovyan/MNIST/', train=True,
                                  download=True, transform=mnist_transforms)
mnist_val_data = datasets.MNIST(root='/home/jovyan/MNIST/', train=False,
                                download=True, transform=mnist_transforms)

In [None]:
# define some hypers (everything a reader might want to tune lives here)
BATCH_SIZE = 32          # samples per mini-batch
LEARNING_RATE = 0.01     # SGD step size
NUM_EPOCHS = 5           # full passes over the training set
MODEL_WEIGHT_SAVE_PATH = '/home/jovyan/best-mnist-val-weights.pth'  # best-validation checkpoint
SEED = 0                 # seeds torch's default RNG (weight init, DataLoader shuffling)

# Seed before the data loaders and model are built so a fresh-kernel
# "Run All" reproduces the same shuffle order and initial weights.
# Trailing ';' suppresses the returned Generator's repr.
torch.manual_seed(SEED);

In [None]:
# data loaders: reshuffle the training set every epoch, keep validation order fixed
train_dataloader = torch.utils.data.DataLoader(
    mnist_train_data, batch_size=BATCH_SIZE, shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
    mnist_val_data, batch_size=BATCH_SIZE, shuffle=False,
)

In [None]:
# define a neural network
class The_best_neural_network(nn.Module):
    """Two-layer MLP classifier for single-channel 28x28 images (e.g. MNIST).

    Each image is flattened to a 784-dim vector, projected to an
    ``embedding_dim``-wide hidden representation followed by a ReLU, and then
    mapped to one logit per class by ``self.fc``.

    Args:
        embedding_dim: width of the hidden layer.
        num_classes: number of output logits (10 digit classes for MNIST).
    """

    # NOTE: input shape is (1, 28, 28) but we map it to 784 (for a simple MLP)
    INPUT_DIM = 28 * 28

    def __init__(self, embedding_dim=128, num_classes=10):
        super(The_best_neural_network, self).__init__()
        self.embedding_dim = embedding_dim
        # Hidden projection: 784 -> embedding_dim.
        self.embedding = nn.Linear(self.INPUT_DIM, embedding_dim)
        # Classifier head: embedding_dim -> one logit per class.
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        Accepts image batches of shape (batch, 1, 28, 28) as well as batches
        that were already flattened to (batch, 784).
        """
        # Flatten everything except the batch dimension; works for both
        # image-shaped and pre-flattened input.
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.embedding(x))
        return self.fc(x)

In [None]:
# define a training epoch
def do_training_epoch(model, dataloader, loss_func, optimizer, device=None):
    """Run one optimization pass over ``dataloader`` and return epoch metrics.

    Args:
        model: the ``nn.Module`` being trained; it is switched to train mode.
        dataloader: iterable of ``(inputs, labels)`` batches.
        loss_func: criterion called as ``loss_func(output, labels)``; assumed to
            return the batch-mean loss (the default for ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(mean_accuracy, mean_loss)``: ``mean_accuracy`` is a 0-dim tensor
        (callers use ``.item()``) and ``mean_loss`` is a Python float, both
        averaged over every instance seen in the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no instances.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # init some metrics
    num_instances = 0
    num_correct = 0
    running_loss = 0.0

    # set the model to be in "train" mode
    model.train()

    # iterate through the dataloader, batch by batch
    for x, labels in dataloader:
        # keep the batch on the same device as the model
        x = x.to(device)
        labels = labels.to(device)

        # zero the gradients left over from the previous step
        optimizer.zero_grad()

        # force gradient tracking on, even if an enclosing scope disabled it
        with torch.set_grad_enabled(True):
            output = model(x)
            loss = loss_func(output, labels)
            loss.backward()
            optimizer.step()

        # metrics need no autograd history; ``output`` predates the step, so
        # accuracy reflects the same weights that produced ``loss``
        with torch.no_grad():
            preds = output.argmax(dim=1)
            batch_size = x.size(0)
            # loss is a batch mean: re-weight by batch size so the epoch mean
            # stays exact when the last batch is smaller
            running_loss += loss.item() * batch_size
            num_correct += (preds == labels).sum()
            num_instances += batch_size

    if num_instances == 0:
        raise ValueError("do_training_epoch: dataloader yielded no instances")

    mean_loss = running_loss / num_instances
    mean_accuracy = num_correct / num_instances

    return mean_accuracy, mean_loss

In [None]:
# define a validation epoch
def do_validation_epoch(model, dataloader, loss_func, device=None):
    """Evaluate ``model`` over ``dataloader`` without updating it.

    Args:
        model: the ``nn.Module`` to evaluate; it is switched to eval mode.
        dataloader: iterable of ``(inputs, labels)`` batches.
        loss_func: criterion called as ``loss_func(output, labels)``; assumed to
            return the batch-mean loss (the default for ``nn.CrossEntropyLoss``).
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(mean_accuracy, mean_loss)``: ``mean_accuracy`` is a 0-dim tensor
        (callers use ``.item()``) and ``mean_loss`` is a Python float, both
        averaged over every instance seen.

    Raises:
        ValueError: if ``dataloader`` yields no instances.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # init some metrics
    num_instances = 0
    num_correct = 0
    running_loss = 0.0

    # set the model to be in "evaluation" mode
    model.eval()

    # evaluation only: disable gradient tracking for the whole pass
    with torch.no_grad():
        for x, labels in dataloader:
            # keep the batch on the same device as the model
            x = x.to(device)
            labels = labels.to(device)

            output = model(x)
            loss = loss_func(output, labels)

            preds = output.argmax(dim=1)
            batch_size = x.size(0)
            # loss is a batch mean: re-weight by batch size so the overall
            # mean stays exact when the last batch is smaller
            running_loss += loss.item() * batch_size
            num_correct += (preds == labels).sum()
            num_instances += batch_size

    if num_instances == 0:
        raise ValueError("do_validation_epoch: dataloader yielded no instances")

    mean_loss = running_loss / num_instances
    mean_accuracy = num_correct / num_instances

    return mean_accuracy, mean_loss

In [None]:
# build model and report its size
net = The_best_neural_network()

print("Model structure: ", net)

# count only the parameters the optimizer will actually update
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("Model parameters: ", num_params)



In [None]:
# build optimizer: SGD with momentum over every model parameter
optimizer = optim.SGD(params=net.parameters(), lr=LEARNING_RATE, momentum=0.9)

In [None]:
# build loss function: cross-entropy on raw logits, averaged over the batch
loss_func = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# keeping track of metrics: best validation accuracy so far (drives
# checkpointing) plus one entry per epoch for each curve
best_val_acc = 0.0
training_losses, training_accs = [], []
val_losses, val_accs = [], []

In [None]:
# train model: one training pass and one validation pass per epoch,
# checkpointing whenever validation accuracy improves
for epoch in range(NUM_EPOCHS):
    print("Let's do it up, epoch number: ", (epoch+1), " of: ", NUM_EPOCHS)

    # train epoch (accuracy comes back as a 0-dim tensor -> convert once)
    train_acc, train_loss = do_training_epoch(net, train_dataloader, loss_func, optimizer)
    train_acc = train_acc.item()

    print("Training loss: ", train_loss, " and accuracy: ", train_acc)

    training_losses.append(train_loss)
    training_accs.append(train_acc)

    # val epoch
    val_acc, val_loss = do_validation_epoch(net, val_dataloader, loss_func)
    val_acc = val_acc.item()

    print("Validation loss: ", val_loss, " and accuracy: ", val_acc)

    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # is this the best epoch yet? if so, save a checkpoint. best_val_acc stays
    # a plain float; 'epoch' is the 0-based loop index.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        checkpoint = {'weights': net.state_dict(),
                      'epoch': epoch,
                      'val_acc': val_acc,
                      }
        torch.save(checkpoint, MODEL_WEIGHT_SAVE_PATH)

In [None]:
# evaluation plots: per-epoch training vs. validation curves

def _plot_train_val_curves(train_values, val_values, title, ylabel):
    """Plot per-epoch training (red) and validation (black) curves on one figure.

    Epochs on the x-axis are numbered from 1, matching the training-loop log.
    """
    fig, ax = plt.subplots()
    ax.plot(range(1, len(train_values) + 1), train_values, color='r', label='Training')
    ax.plot(range(1, len(val_values) + 1), val_values, color='k', label="Validation")
    ax.set(title=title, ylabel=ylabel, xlabel="Epoch")
    ax.legend()
    plt.show()
    plt.close(fig)


_plot_train_val_curves(training_accs, val_accs, "Accuracy plots", "Acc.")
_plot_train_val_curves(training_losses, val_losses, "Loss curves", "CE Loss")

In [None]:
# load best model onto the CPU for inspection; map_location lets a checkpoint
# saved from GPU tensors load on a CPU-only kernel
my_model = The_best_neural_network()
my_model.eval()  # inference only from here on
loaded_state_dict = torch.load(MODEL_WEIGHT_SAVE_PATH, map_location='cpu')
my_model.load_state_dict(loaded_state_dict['weights'])

In [None]:
# let's look at a batch ourselves: classifier-head parameters first
print(my_model.fc.weight.shape)
print(my_model.fc.bias)

batch = next(iter(val_dataloader))
x, labels = batch
print(x.shape)
# flatten each image to a 784-dim vector
x = x.view(-1, 784)
print(x.shape)

# inference only: no autograd graph needed
with torch.no_grad():
    out = my_model(x)
print(out.shape)
print(out[:4])

In [None]:
# these are the logits; softmax over the class dimension (dim=1) turns each
# row into "probability" scores that sum to 1
softmax = F.softmax(out, dim=1)
print(softmax[:4])

In [None]:
# predicted class = index of the highest score in each row
preds = softmax.argmax(dim=1)
print(preds)
print(labels)
# number of correct predictions in this batch
print(torch.eq(preds, labels).sum().item())

In [None]:
# pass data through layer by layer