### Imports

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable



### Load Dataset

In [2]:
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)
 
test_dataset = dsets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

### Make Dataset Iterable

In [3]:
batch_size = 100
n_iters = 6000
num_epochs = n_iters / (len(train_dataset) / batch_size)
num_epochs = int(num_epochs)
 
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
 
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

### Model

In [4]:
class MYGRUCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True):
        super(MYGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.x2h = nn.Linear(input_size, 3 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
               
    def forward(self, x, hidden):
        
        x = x.view(-1, x.size(1))
        
        gate_x = self.x2h(x).squeeze() 
        gate_h = self.h2h(hidden).squeeze()
        
        input_reset, input_z, input_h = gate_x.chunk(3, 1)
        hidden_reset, hidden_z, hidden_h = gate_h.chunk(3, 1)
        
        
        resetgate = torch.sigmoid(input_reset + hidden_reset)
        zgate = torch.sigmoid(input_z + hidden_z)
        hgate = torch.tanh(input_h + (resetgate * hidden_h))
        
        output = hgate*zgate + hidden-zgate # hgate + zgate * (hidden - hgate) 
        
        
        return output
    
    
class MYGRUModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, bias=True):
        super(MYGRUModel, self).__init__()
        # Hidden dimensions
        self.hidden_dim = hidden_dim
         
        # Number of hidden layers
        self.layer_dim = layer_dim
        self.gru_cell = MYGRUCell(input_dim, hidden_dim, layer_dim)   
        self.fc = nn.Linear(hidden_dim, output_dim)
     
    def forward(self, x):
        
        # Initialize hidden state with zeros
        #######################
        #  USE GPU FOR MODEL  #
        #######################
        if torch.cuda.is_available():
            h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).cuda())
        else:
            h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim))
         
       
        outs = []
        hn = h0[0,:,:]
        
        for seq in range(x.size(1)):
            hn = self.gru_cell(x[:,seq,:], hn) 
            outs.append(hn)
            
        out = outs[-1].squeeze()        
        out = self.fc(out) 
        
        return out


### INSTANTIATE MODEL CLASS

In [5]:
input_dim = 28
hidden_dim = 100
layer_dim = 3  # ONLY CHANGE IS HERE FROM ONE LAYER TO TWO LAYER
output_dim = 10
 
model = MYGRUModel(input_dim, hidden_dim, layer_dim, output_dim)
 
#######################
#  USE GPU FOR MODEL  #
#######################
 
if torch.cuda.is_available():
    model.cuda()

### Train

In [6]:

criterion = nn.CrossEntropyLoss()
learning_rate = 0.1 
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  
 
# Number of steps to unroll
seq_dim = 28 
 
iter = 0
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Load images as Variable
        #######################
        #  USE GPU FOR MODEL  #
        #######################
        if torch.cuda.is_available():
            images = images.view(-1, seq_dim, input_dim).cuda()
            labels = labels.cuda()
        else:
            images = images.view(-1, seq_dim, input_dim)

             
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
         
        # Forward pass to get output/logits
        # outputs.size() --> 100, 10
        outputs = model(images)
         
        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)
         
        # Getting gradients w.r.t. parameters
        loss.backward()
         
        # Updating parameters
        optimizer.step()
         
        iter += 1
         
        if iter % 500 == 0:
            # Calculate Accuracy         
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
                #######################
                #  USE GPU FOR MODEL  #
                #######################
                if torch.cuda.is_available():
                    images = images.view(-1, seq_dim, input_dim).cuda()

                 
                # Forward pass only to get logits/output
                outputs = model(images)
                 
                # Get predictions from the maximum value
                _, predicted = torch.max(outputs.data, 1)
                 
                # Total number of labels
                total += labels.size(0)
                 
                # Total correct predictions
                #######################
                #  USE GPU FOR MODEL  #
                #######################
                if torch.cuda.is_available():
                    correct += (predicted.cpu() == labels.cpu()).sum()
                else:
                    correct += (predicted == labels).sum()
             
            accuracy = 100 * correct / total
             
            # Print Loss
            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))

Iteration: 500. Loss: 1.4569811820983887. Accuracy: 42.7400016784668
Iteration: 1000. Loss: 0.6063268184661865. Accuracy: 76.58999633789062
Iteration: 1500. Loss: 0.42163440585136414. Accuracy: 83.06999969482422
Iteration: 2000. Loss: 0.6761084198951721. Accuracy: 74.2300033569336
Iteration: 2500. Loss: 0.48743778467178345. Accuracy: 90.5199966430664
Iteration: 3000. Loss: 0.15570218861103058. Accuracy: 92.5
Iteration: 3500. Loss: 0.2698049545288086. Accuracy: 92.76000213623047
Iteration: 4000. Loss: 0.22820596396923065. Accuracy: 93.36000061035156
Iteration: 4500. Loss: 0.11312955617904663. Accuracy: 93.56999969482422
Iteration: 5000. Loss: 0.21811266243457794. Accuracy: 94.38999938964844
Iteration: 5500. Loss: 0.1607332080602646. Accuracy: 94.33000183105469
Iteration: 6000. Loss: 0.21647486090660095. Accuracy: 93.87999725341797
