# RNN on MNIST
<img src="./LSTM.png">

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import numpy as np


In [2]:
# 1. Hyper Parameters
input_size = 28
sequence_size = 28
hidden_size = 128
num_layers = 1
num_classes = 10

learning_rate = 0.01
batch_size = 1
ephoc_size = 2


In [3]:
# 2. Preparing datasets
    # MNIST Dataset (Images and Labels)
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

    # Dataset Loader (Input Pipline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
# 3. Build a model
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fnn = nn.Linear(hidden_size, num_classes)

    def forward(self, x):

        out, _ = self.rnn(x)
        #print(out, out.size())
        out_tmp = out[:, -1, :] 
        #print(out_tmp)   # 1x128 takes the last hidden layer  
        output = self.fnn(out_tmp) #output = 1x10
        #print(output)
        
        return output

In [5]:
# 4. Generate the model
model = RNNModel(input_size, hidden_size, num_layers, num_classes)

In [6]:
# 5. Set Loss and Optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
# 6. Train
for ephoc in range(ephoc_size):
    for idx, (images, labels) in enumerate(train_loader):
        # convert dataset as for Pytorch type
        images = Variable(images.view(-1, sequence_size, input_size))
        labels = Variable(labels)
        
        # Forward, Backward and Gradient decent
        optimizer.zero_grad()
        output = model(images)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        if idx%100 == 0:
            print("loss:", loss.item())

loss: 2.3811349868774414
loss: 6.629345893859863
loss: 2.1270318031311035
loss: 1.9513318538665771
loss: 2.067624568939209
loss: 4.720172882080078
loss: 3.350475311279297
loss: 3.827902317047119
loss: 1.5780283212661743
loss: 1.9668445587158203
loss: 4.359375953674316
loss: 5.317937850952148
loss: 0.5633292198181152
loss: 1.1902670860290527
loss: 2.626038074493408
loss: 3.2573401927948
loss: 3.3966360092163086
loss: 2.602323293685913
loss: 2.7518341541290283
loss: 5.73361873626709
loss: 4.8019304275512695
loss: 2.0918073654174805
loss: 3.300025463104248
loss: 3.2501797676086426
loss: 1.8399460315704346
loss: 4.838371276855469
loss: 2.3016364574432373
loss: 2.0933027267456055
loss: 2.919473171234131
loss: 3.779080390930176
loss: 3.208118438720703
loss: 1.9183874130249023
loss: 1.7078962326049805
loss: 0.7306676506996155
loss: 0.7928506731987
loss: 4.043917655944824
loss: 2.164327383041382
loss: 0.7351899147033691
loss: 1.8354421854019165
loss: 2.65408992767334
loss: 4.848973274230957
lo

In [None]:
# 7. Test
correct = 0
total = 0
for images, labels in test_loader:
    images = Variable(images.view(-1, sequence_size, input_size))
    
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += len(predicted)
    correct += (predicted == labels).sum()
    
print("accuracy:", correct.item()/total)