In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor
import torch.nn.functional as F
import numpy as np

from cell_model import JANETCell, CIFGCell, NRUCell
from rnn_model import RNNModel
import math

In [2]:
cuda = True if torch.cuda.is_available() else False
device = 'cuda' if cuda else 'cpu'    
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 

In [3]:
#permute MNIST
rng_permute = np.random.RandomState(92916)
idx_permute = torch.from_numpy(rng_permute.permutation(784)) 
transform = transforms.Compose([transforms.ToTensor(), 
                                            transforms.Lambda(lambda x: x.view(-1)[np.array(idx_permute)].view(1, 28, 28) )])

In [4]:
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transform,
                            download=True)
 
test_dataset = dsets.MNIST(root='./data', 
                           train=False, 
                           transform=transform)
 
batch_size = 100
n_iters = 6000
num_epochs = n_iters / (len(train_dataset) / batch_size)
num_epochs = int(num_epochs)

In [5]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
 
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

In [6]:
input_dim = 28
hidden_dim = 128
layer_dim = 1  # ONLY CHANGE IS HERE FROM ONE LAYER TO TWO LAYER
output_dim = 10
cell_type = "gru"    # "gru" "lstm" "janet" "nru" "cifg" 

# model
model = RNNModel(input_dim, hidden_dim, layer_dim, output_dim, device, cell_type)

 
if torch.cuda.is_available():
    model.cuda()
     
#loss
criterion = nn.CrossEntropyLoss()
 
#optimizer
learning_rate = 0.1
 
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [7]:
# Number of steps to unroll
seq_dim = 28 

loss_list = []
iter = 0
    
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Load images as Variable
        #######################
        #  USE GPU FOR MODEL  #
        #######################
          
        if cuda:
            images = Variable(images.view(-1, seq_dim, input_dim).cuda())
            labels = Variable(labels.cuda())
        else:
            images = Variable(images.view(-1, seq_dim, input_dim))
            labels = Variable(labels)
                    
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
         
        outputs = model(images)

        loss = criterion(outputs, labels)

        if torch.cuda.is_available():
            loss.cuda()

        # Getting gradients w.r.t. parameters
        loss.backward()

        # Updating parameters
        optimizer.step()
        
        loss_list.append(loss.item())
        iter += 1
         
        if iter % 500 == 0:
            # Calculate Accuracy         
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
                #######################
                #  USE GPU FOR MODEL  #
                #######################
                if cuda:
                    images = Variable(images.view(-1, seq_dim, input_dim).cuda())
                else:
                    images = Variable(images.view(-1 , seq_dim, input_dim))
                
                # Forward pass only to get logits/output
                outputs = model(images)
                
                # Get predictions from the maximum value
                _, predicted = torch.max(outputs.data, 1)
                 
                # Total number of labels
                total += labels.size(0)
                
                if cuda:
                    correct += (predicted.cpu() == labels.cpu()).sum()
                else:
                    correct += (predicted == labels).sum()
             
            accuracy = 100 * correct // total
             
            # Print Loss
            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))




Iteration: 500. Loss: 1.8285527229309082. Accuracy: 40
Iteration: 1000. Loss: 1.1653125286102295. Accuracy: 65
Iteration: 1500. Loss: 0.7396847009658813. Accuracy: 73
Iteration: 2000. Loss: 0.4139963984489441. Accuracy: 82
Iteration: 2500. Loss: 0.481403112411499. Accuracy: 81
Iteration: 3000. Loss: 0.7043436169624329. Accuracy: 85
Iteration: 3500. Loss: 0.4631727635860443. Accuracy: 88
Iteration: 4000. Loss: 0.32297036051750183. Accuracy: 88
Iteration: 4500. Loss: 0.3393784761428833. Accuracy: 88
Iteration: 5000. Loss: 0.12515024840831757. Accuracy: 91
Iteration: 5500. Loss: 0.2820318043231964. Accuracy: 91
Iteration: 6000. Loss: 0.3977452218532562. Accuracy: 92
