In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

# Load MNIST Dataset

In [2]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.5 and std 1.0.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,)),
])

# Both splits live under ./data; only the training split asks for a download.
train_set = dset.MNIST(root='./data', train=True, transform=trans, download=True)
test_set = dset.MNIST(root='./data', train=False, transform=trans)

batch_size = 128

# Training batches are reshuffled; evaluation batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False)


# Build RNN Model
![](images/RNN.png)

In [3]:
class RNN(nn.Module):
    """GRU-based classifier that reads an image one row per time step.

    A (batch, 1, seq_len, input_size) input is treated as a sequence of
    ``seq_len`` vectors of width ``input_size``. The GRU output at the last
    time step is projected to class logits by a linear layer.

    Args:
        input_size: Number of features per time step (image width).
        hidden_size: Size of the GRU hidden state.
        n_class: Number of output classes (width of the logits).
    """

    def __init__(self, input_size=28, hidden_size=64, n_class=10):
        super(RNN, self).__init__()
        # Attribute names are kept unchanged so existing state_dict keys still load.
        self.RNN = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        # Use n_class instead of a hard-coded 10 so the argument takes effect.
        self.fc = nn.Linear(hidden_size, n_class)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Tensor of shape (batch, 1, seq_len, input_size); an input that
                is already (batch, seq_len, input_size) is passed through as-is.

        Returns:
            Tensor of shape (batch, n_class) holding unnormalised logits.
        """
        # Remove only the channel axis. A bare squeeze() would also drop the
        # batch axis when batch == 1 and break the indexing below.
        if x.dim() == 4:
            x = x.squeeze(1)  # (batch, 1, seq_len, input_size) => (batch, seq_len, input_size)
        out, _ = self.RNN(x)  # out: (batch, seq_len, hidden_size)
        # Keep the output at the last time step only.
        out = out[:, -1, :]   # (batch, hidden_size)
        logits = self.fc(out)
        return logits


# Training

In [4]:
model = RNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    # ---- training ----
    model.train()
    # Exponential moving average of the batch loss (decay 0.9). It starts at 0,
    # so the first few reported values are biased low.
    ave_loss = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, target)
        # .item() extracts the Python float; indexing a 0-dim loss tensor
        # (loss.data[0]) raises on current PyTorch.
        ave_loss = ave_loss * 0.9 + loss.item() * 0.1
        loss.backward()
        optimizer.step()
        # Report every 100 batches and once more on the final batch.
        if (batch_idx+1) % 100 == 0 or (batch_idx+1) == len(train_loader):
            print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                epoch, batch_idx+1, ave_loss))

    # ---- testing ----
    model.eval()
    correct_cnt, ave_loss = 0, 0
    total_cnt = 0
    # no_grad() replaces the removed `volatile=True` flag: no autograd graph is
    # built during evaluation.
    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(test_loader):
            logits = model(x)
            loss = criterion(logits, target)
            _, pred_label = torch.max(logits, 1)
            total_cnt += x.size(0)
            # Accumulate a plain int so the accuracy below is ordinary float math.
            correct_cnt += (pred_label == target).sum().item()
            # smooth average
            ave_loss = ave_loss * 0.9 + loss.item() * 0.1

            if (batch_idx+1) % 100 == 0 or (batch_idx+1) == len(test_loader):
                print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                    epoch, batch_idx+1, ave_loss, correct_cnt * 1.0 / total_cnt))

==>>> epoch: 0, batch index: 100, train loss: 1.481127
==>>> epoch: 0, batch index: 200, train loss: 0.799926
==>>> epoch: 0, batch index: 300, train loss: 0.561906
==>>> epoch: 0, batch index: 400, train loss: 0.437470
==>>> epoch: 0, batch index: 469, train loss: 0.402823
==>>> epoch: 0, batch index: 79, test loss: 0.316786, acc: 0.889
==>>> epoch: 1, batch index: 100, train loss: 0.330944
==>>> epoch: 1, batch index: 200, train loss: 0.280307
==>>> epoch: 1, batch index: 300, train loss: 0.234186
==>>> epoch: 1, batch index: 400, train loss: 0.226141
==>>> epoch: 1, batch index: 469, train loss: 0.215579
==>>> epoch: 1, batch index: 79, test loss: 0.174038, acc: 0.940
==>>> epoch: 2, batch index: 100, train loss: 0.211807
==>>> epoch: 2, batch index: 200, train loss: 0.179505
==>>> epoch: 2, batch index: 300, train loss: 0.200663
==>>> epoch: 2, batch index: 400, train loss: 0.155425
==>>> epoch: 2, batch index: 469, train loss: 0.153705
==>>> epoch: 2, batch index: 79, test loss: 0