In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms

import numpy as np
from tqdm import tqdm

In [19]:
# Hyper-parameter
hidden_size = 128
num_classes = 10
num_epochs = 10
batch_size = 100
learning_rate = 0.001

input_size = 28
sequence_length = 28
num_layers = 2

In [10]:
# preprocess data
train_dataset = torchvision.datasets.MNIST(root='./data',train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data',train=False, transform=transforms.ToTensor(), download=True)

In [11]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=2)

In [29]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.num_layers =  num_layers
        self.hidden_size = hidden_size

        #self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        #self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        # x -> (batch_size, sequence_length, input_size) because batch_size = true
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # initial hidden state size is always (num_layer, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        #out, _ = self.rnn(x, h0)
        #out, _ = self.gru(x, h0)
        out, _ = self.lstm(x, (h0,c0))
        
        # out -> (batch_size, sequence_length, hidden_size) because batch_size = true
        out = out[:, -1, :] # only the last time step
        out = self.fc(out)
        return out

net = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.5)

In [30]:
#net = net.float()
net.train()
for epoch in range(num_epochs):  # loop over the dataset multiple times
    print("\nStarting epoch {}".format(epoch+1))
    
    total = 0
    running_loss = 0.0

    # to make a beautiful progress bar
    loader = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, data in loader:
        # get the data points
        inputs, labels = data
        inputs, labels = inputs.reshape(-1,sequence_length, input_size).to(device), labels.to(device)
        # zero the parameter gradients (else, they are accumulated)
        optimizer.zero_grad()

        # forward the data through the network
        outputs = net(inputs)
        # calculate the loss given the output of the network and the target labels
        loss = criterion(outputs, labels)
        # calculate the gradients of the network w.r.t. its parameters
        loss.backward()
        # Let the optimiser take an optimization step using the calculated gradients
        optimizer.step()
        
        running_loss += loss
        total += outputs.size(0)

        loader.set_description("loss: {:.5f}".format(running_loss/total))

print('Finished Training')


Starting epoch 1


loss: 0.00214: 100%|██████████| 118/118 [00:05<00:00, 20.66it/s]


Starting epoch 2



loss: 0.00049: 100%|██████████| 118/118 [00:05<00:00, 20.82it/s]


Starting epoch 3



loss: 0.00030: 100%|██████████| 118/118 [00:05<00:00, 20.67it/s]


Starting epoch 4



loss: 0.00023: 100%|██████████| 118/118 [00:05<00:00, 20.86it/s]


Starting epoch 5



loss: 0.00018: 100%|██████████| 118/118 [00:05<00:00, 20.92it/s]


Starting epoch 6



loss: 0.00015: 100%|██████████| 118/118 [00:05<00:00, 20.66it/s]


Starting epoch 7



loss: 0.00013: 100%|██████████| 118/118 [00:05<00:00, 20.51it/s]


Starting epoch 8



loss: 0.00011: 100%|██████████| 118/118 [00:05<00:00, 20.88it/s]


Starting epoch 9



loss: 0.00010: 100%|██████████| 118/118 [00:05<00:00, 20.68it/s]


Starting epoch 10



loss: 0.00008: 100%|██████████| 118/118 [00:05<00:00, 20.64it/s]

Finished Training





In [31]:
net.eval()
class Accuracy:
    """A class to keep track of the accuracy while training"""
    def __init__(self):
        self.correct = 0
        self.total = 0
        
    def reset(self):
        """Resets the internal state"""
        self.correct = 0
        self.total = 0
        
    def update(self, output, labels):
        """
        Updates the internal state to later compute the overall accuracy
        
        output: the output of the network for a batch
        labels: the target labels
        """
        _, predicted = torch.max(output.data, 1) # predicted now contains the predicted class index/label
        
        self.total += labels.size(0)
        self.correct += (predicted == labels).sum().item() # .item() gets the number, not the tensor

    def compute(self):
        return self.correct/self.total

accuracy = Accuracy()

accuracy.reset()
# Gradients are calculated on the forward pass for every iteration.
# As we do not need gradients now, we can disable the calculation.
with torch.no_grad():
    for data in tqdm(train_loader):
        # get the data points
        inputs, labels = data
        inputs, labels = inputs.reshape(-1,sequence_length, input_size).to(device), labels.to(device)
        # forward the data through the network
        outputs = net(inputs)
        
        accuracy.update(outputs, labels)

print("Training Accuracy: {:.2f}%".format(100 * accuracy.compute()))

accuracy.reset()        
with torch.no_grad():
    for data in tqdm(val_loader):
        # get the data points
        inputs, labels = data
        inputs, labels = inputs.reshape(-1,sequence_length, input_size).to(device), labels.to(device)
        # forward the data through the network
        outputs = net(inputs)
        
        accuracy.update(outputs, labels)
        
print("\nTesting Accuracy: {:.2f}%".format(100 * accuracy.compute()))

100%|██████████| 118/118 [00:05<00:00, 22.72it/s]


Training Accuracy: 98.54%


100%|██████████| 20/20 [00:00<00:00, 20.72it/s]


Testing Accuracy: 98.01%



