In [1]:
import os
import torch
import torchvision

import torchvision.transforms as transforms
import numpy as np
import matplotlib.pylab as plt

from utils import compute_confusion_matrix, train, evaluate
from networks import RecurrentNN

# Hyperparameters and configuration

In [2]:
# Configuration cell: every value a reader might want to tune lives here.

# Seed for all random number generators used by this notebook. Weight
# initialization and DataLoader shuffling draw from torch's global RNG, and the
# random "mistake" example is drawn from numpy's global RNG; seeding both makes
# a Restart & Run All reproduce the same numbers.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# pendulum
# NOTE(review): m, g, l and dt are not referenced anywhere else in the visible
# notebook; they look like leftovers from another experiment - confirm and remove.
m = 1.
g = 9.81
l = 1.
dt = 1e-2

# Directory the EMNIST files are downloaded to / read from.
datadir = '../datasets'

# Network input layout and size.
sequence_len = 28
indim = 28
hdim = 128
outdim = 10
num_layers = 3

# Optimization settings.
batch_size = 32
learning_rate = 0.001
num_epochs = 1

# Passed to train(); the captured log shows one line every 100 iterations.
print_every = 100

# Dataset

In [3]:
# Preprocessing pipeline: convert each PIL.Image into a Tensor (the data type
# pytorch works with), then normalize it with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

# Load the 'mnist' split of EMNIST from datadir, downloading it if needed.
# The training set relies on the dataset's default train flag; the test set
# asks for train=False explicitly.
emnist_traindata = torchvision.datasets.EMNIST(
    datadir, split='mnist', download=True, transform=transform
)
emnist_testdata = torchvision.datasets.EMNIST(
    datadir, split='mnist', train=False, download=True, transform=transform
)

# The leading axis of the raw data tensor is the number of examples.
train_num_examples = emnist_traindata.data.shape[0]
test_num_examples = emnist_testdata.data.shape[0]

count_summary = 'Training dataset has {train_num_examples}, test dataset has {test_num_examples}'
print(count_summary.format(train_num_examples=train_num_examples, test_num_examples=test_num_examples))

Training dataset has 60000, test dataset has 10000


# Recurrent neural network

In [4]:
# Build the recurrent classifier from the hyperparameters defined in the
# configuration cell.
# NOTE(review): the arguments are passed positionally, so their order must match
# the RecurrentNN constructor in networks.py, which is not visible here - confirm.
rnn = RecurrentNN(indim, hdim, outdim, num_layers, sequence_len)

In [5]:
# Materialize the parameters once. rnn.parameters() hands back a one-shot
# iterator, so storing it directly would leave an exhausted iterator in
# `params` after counting and any later cell reusing it would see nothing.
params = list(rnn.parameters())

# numel() is each tensor's element count as a plain Python int, so the total
# stays an exact integer.
num_params = sum(p.numel() for p in params)
print('The number of parameters in the network is: {}'.format(num_params))

The number of parameters in the network is: 87562


In [6]:
# Objective: cross-entropy loss.
loss_fcn = torch.nn.CrossEntropyLoss()

# Optimizer: Adam over every parameter of the network, using the learning rate
# from the configuration cell.
optimizer = torch.optim.Adam(params=rnn.parameters(), lr=learning_rate)

In [7]:
# create dataloaders

# Training: reshuffle the examples and drop the final incomplete batch so that
# every optimization step sees exactly batch_size examples.
trainloader = torch.utils.data.DataLoader(
    emnist_traindata, batch_size=batch_size, shuffle=True, drop_last=True
)

# Testing: keep a fixed order and keep every example. The test set has 10000
# examples, which is not a multiple of batch_size=32, so drop_last=True would
# silently leave 16 of them out of the reported metrics; shuffling adds nothing
# to an evaluation pass and only makes it non-deterministic.
# NOTE(review): assumes evaluate() and RecurrentNN accept a final batch smaller
# than batch_size - confirm against utils.py / networks.py.
testloader = torch.utils.data.DataLoader(
    emnist_testdata, batch_size=batch_size, shuffle=False, drop_last=False
)

In [None]:
# Put the network in training mode. Assuming rnn is a torch.nn.Module, this only
# switches mode-dependent layers (e.g. dropout) to their training behavior; it
# does not change which weights are trainable.
rnn.train()

# Run the training loop (train comes from utils). The returned value is plotted
# against "Iteration" in the next cell, so it is presumably a per-iteration
# sequence of loss values - confirm against utils.train.
training_loss = train(num_epochs, print_every, trainloader, loss_fcn, optimizer, rnn)

Epoch: 0, Iteration: 0, Loss: 2.34, Acc: 0.06
Epoch: 0, Iteration: 100, Loss: 1.78, Acc: 0.44
Epoch: 0, Iteration: 200, Loss: 1.21, Acc: 0.62
Epoch: 0, Iteration: 300, Loss: 1.09, Acc: 0.59
Epoch: 0, Iteration: 400, Loss: 1.02, Acc: 0.62
Epoch: 0, Iteration: 500, Loss: 0.92, Acc: 0.69
Epoch: 0, Iteration: 600, Loss: 0.66, Acc: 0.81
Epoch: 0, Iteration: 700, Loss: 0.39, Acc: 0.84
Epoch: 0, Iteration: 800, Loss: 0.74, Acc: 0.81
Epoch: 0, Iteration: 900, Loss: 0.38, Acc: 0.91
Epoch: 0, Iteration: 1000, Loss: 0.42, Acc: 0.88
Epoch: 0, Iteration: 1100, Loss: 0.58, Acc: 0.88
Epoch: 0, Iteration: 1200, Loss: 0.36, Acc: 0.94
Epoch: 0, Iteration: 1300, Loss: 1.31, Acc: 0.66
Epoch: 0, Iteration: 1400, Loss: 0.40, Acc: 0.88
Epoch: 0, Iteration: 1500, Loss: 0.45, Acc: 0.84
Epoch: 0, Iteration: 1600, Loss: 0.23, Acc: 0.91
Epoch: 0, Iteration: 1700, Loss: 0.58, Acc: 0.81


In [None]:
# Draw the recorded training loss on an explicitly created axes.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(training_loss)
loss_ax.set_title('Training loss')
loss_ax.set_xlabel('Iteration')
loss_ax.set_ylabel('Cross Entropy Loss')

In [None]:
# Switch the network to evaluation mode before scoring the test set.
rnn.eval()

# evaluate (from utils) yields three values, unpacked here as the average
# accuracy, the average loss, and the per-example prediction/label data that
# the confusion-matrix cell consumes.
eval_results = evaluate(testloader, loss_fcn, rnn)
average_accuracy, average_loss, prediction_label_data = eval_results

summary_template = 'Avg Loss: {loss:.2f}, Avg Acc: {acc:.2f}'
print(summary_template.format(loss=average_loss, acc=average_accuracy))

In [None]:
# Summarize the test predictions: a confusion matrix plus the list of
# misclassified examples, each unpacked below as (prediction, label, data).
confusion_matrix, mistakes = compute_confusion_matrix(prediction_label_data)

# Two panels side by side: confusion matrix on the left, one mistake on the right.
fig, (ax_confusion, ax_mistake) = plt.subplots(1, 2)

# log1p instead of log: entries that are zero (presumably prediction/label
# pairs that never occur) would give log(0) = -inf, which triggers a
# divide-by-zero warning and leaves holes in the image. log1p maps 0 to 0 and
# keeps the same logarithmic compression for the large diagonal entries.
ax_confusion.imshow(np.log1p(confusion_matrix))
ax_confusion.set_title('Log Confusion matrix')
ax_confusion.set_xlabel('label')
ax_confusion.set_ylabel('prediction')

if len(mistakes) > 0:
    # get image with wrong prediction
    mistake_idx = np.random.randint(len(mistakes))
    mistake_pred, mistake_label, mistake_data = mistakes[mistake_idx]
    # Drop singleton axes and transpose before display.
    # NOTE(review): the transpose presumably undoes EMNIST's transposed image
    # storage - confirm.
    mistake_img = mistake_data.squeeze().T

    ax_mistake.imshow(mistake_img)
    ax_mistake.set_title('Label: {label}, Prediction {pred}'.format(label=mistake_label, pred=mistake_pred))
else:
    # np.random.randint(0) raises ValueError, so an error-free evaluation needs
    # its own branch instead of crashing the cell.
    ax_mistake.set_title('No misclassified examples')
    ax_mistake.axis('off')

# Keep the axis labels and titles of the two panels from overlapping.
fig.tight_layout()

In [None]:
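# Further reading:
# - C. Olah, "Understanding LSTM Networks": https://colah.github.io/posts/2015-08-Understanding-LSTMs/
# - C. Olah and S. Carter, "Attention and Augmented Recurrent Neural Networks": https://distill.pub/2016/augmented-rnns/
# - Krishnamurthy et al., "Theory of gating in recurrent neural networks": https://arxiv.org/abs/2007.14823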