RNN tutorial: classifying MNIST digits with an LSTM
---

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import matplotlib.pyplot as plt
from FileReader import load_mnist

In [2]:
# Load the MNIST training split and wrap its arrays as torch tensors.
train_set = load_mnist('train')
train_imgs = torch.from_numpy(train_set['images']).float()    # float32 network inputs
train_labels = torch.from_numpy(train_set['labels']).long()   # int64 class ids, as CrossEntropyLoss targets require
# NOTE(review): pixel values are passed through unscaled; if load_mnist returns raw
# 0-255 intensities, consider dividing by 255 -- and apply the same transform to the test split.
print(train_imgs.shape)
print(train_labels.shape)

torch.Size([60000, 28, 28])
torch.Size([60000])


In [3]:
train_imgs_num = len(train_imgs)
# Each mini-batch holds 5% of the training images (60000 * 0.05 = 3000, i.e. 20 steps per epoch).
BATCH_SIZE = int(0.05 * train_imgs_num)


In [4]:
# Pair every image with its label and serve mini-batches reshuffled each epoch.
dataset = Data.TensorDataset(train_imgs, train_labels)
loader = Data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


In [5]:
class MyLSTM(nn.Module):
    """Unidirectional LSTM classifier that labels a sequence from its last time step.

    Args:
        in_feature: number of features per time step (28 pixels per image row here).
        hidden_size: size of the LSTM hidden state (default 64, the original value).
        num_layers: number of stacked LSTM layers (default 1).
        num_classes: number of output logits (default 10).

    The submodule attribute names ``LSTM`` and ``outLayer`` are kept unchanged so
    previously saved state_dicts still load into this class.
    """

    def __init__(self, in_feature, hidden_size=64, num_layers=1, num_classes=10):
        super(MyLSTM, self).__init__()

        self.LSTM = nn.LSTM(
            input_size=in_feature,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        # The classifier head is sized from the same hidden_size, so the two can never drift apart.
        self.outLayer = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised class scores (logits).

        Args:
            x: tensor of shape (batch_size, time_step, input_size); the LSTM is
               built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch_size, num_classes).
        """
        # r_out holds the top layer's hidden state at every time step; the final
        # (h_n, c_n) state tuple is not needed. Omitting the initial state is the
        # same as passing None: it defaults to zeros.
        r_out, _ = self.LSTM(x)
        # Classify from the top layer's output at the last time step.
        return self.outLayer(r_out[:, -1, :])


# Each 28x28 image is read row by row: 28 time steps of 28 pixel features each.
model = MyLSTM(in_feature=28)
print(model)

MyLSTM(
  (LSTM): LSTM(28, 64, batch_first=True)
  (outLayer): Linear(in_features=64, out_features=10, bias=True)
)


In [6]:
# Train with Adam on cross-entropy. The model's head is a plain Linear layer
# (no softmax), so its raw scores are exactly what nn.CrossEntropyLoss expects.
NUM_EPOCHS = 20
LEARNING_RATE = 0.01

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.CrossEntropyLoss()

# Re-enable training mode in case an evaluation cell (model.eval()) ran before this one.
model.train()

# `range` works on Python 2 and 3; `xrange` raises NameError on Python 3.
for epoch in range(NUM_EPOCHS):
    for step, (batch_imgs, batch_labels) in enumerate(loader):
        y_pred = model(batch_imgs)  # (batch, n_classes) logits
        loss = loss_func(y_pred, batch_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('epoch {} step {} loss: {}'.format(epoch + 1, step + 1, loss.item()))

epoch 1 step 1 loss: 2.30822682381
epoch 1 step 2 loss: 2.21908259392
epoch 1 step 3 loss: 2.12824487686
epoch 1 step 4 loss: 2.00662589073
epoch 1 step 5 loss: 1.85864579678
epoch 1 step 6 loss: 1.7079886198
epoch 1 step 7 loss: 1.53588056564
epoch 1 step 8 loss: 1.42300570011
epoch 1 step 9 loss: 1.26240897179
epoch 1 step 10 loss: 1.17355144024
epoch 1 step 11 loss: 1.04684782028
epoch 1 step 12 loss: 1.00597512722
epoch 1 step 13 loss: 0.958687245846
epoch 1 step 14 loss: 0.896046876907
epoch 1 step 15 loss: 0.863631188869
epoch 1 step 16 loss: 0.80866932869
epoch 1 step 17 loss: 0.733355224133
epoch 1 step 18 loss: 0.741290688515
epoch 1 step 19 loss: 0.692173719406
epoch 1 step 20 loss: 0.713667452335
epoch 2 step 1 loss: 0.622769355774
epoch 2 step 2 loss: 0.657671332359
epoch 2 step 3 loss: 0.59661090374
epoch 2 step 4 loss: 0.598862886429
epoch 2 step 5 loss: 0.593277812004
epoch 2 step 6 loss: 0.561823010445
epoch 2 step 7 loss: 0.52242577076
epoch 2 step 8 loss: 0.5330000519

epoch 12 step 6 loss: 0.20428314805
epoch 12 step 7 loss: 0.235816687346
epoch 12 step 8 loss: 0.214053735137
epoch 12 step 9 loss: 0.238327622414
epoch 12 step 10 loss: 0.237817913294
epoch 12 step 11 loss: 0.225238546729
epoch 12 step 12 loss: 0.21083265543
epoch 12 step 13 loss: 0.242616146803
epoch 12 step 14 loss: 0.230630844831
epoch 12 step 15 loss: 0.219717651606
epoch 12 step 16 loss: 0.237961798906
epoch 12 step 17 loss: 0.220001995564
epoch 12 step 18 loss: 0.227932557464
epoch 12 step 19 loss: 0.231584817171
epoch 12 step 20 loss: 0.225674852729
epoch 13 step 1 loss: 0.213577210903
epoch 13 step 2 loss: 0.229484215379
epoch 13 step 3 loss: 0.206876382232
epoch 13 step 4 loss: 0.220386072993
epoch 13 step 5 loss: 0.217208713293
epoch 13 step 6 loss: 0.223251014948
epoch 13 step 7 loss: 0.210183545947
epoch 13 step 8 loss: 0.229777023196
epoch 13 step 9 loss: 0.202382072806
epoch 13 step 10 loss: 0.222196772695
epoch 13 step 11 loss: 0.221987262368
epoch 13 step 12 loss: 0.21

In [7]:
# Load the held-out MNIST test split with the same tensor dtypes used for training.
test_set = load_mnist('t10k')
test_imgs = torch.from_numpy(test_set['images']).float()     # float32 network inputs
test_labels = torch.from_numpy(test_set['labels']).long()    # int64 class ids
print(test_imgs.shape)
print(test_labels.shape)

torch.Size([10000, 28, 28])
torch.Size([10000])


In [8]:
# Score the whole test set in one forward pass.
# eval() switches off training-only behaviour, and no_grad() stops autograd from
# recording the graph; otherwise every intermediate LSTM activation for all 10000
# test sequences would be kept in memory for a backward pass that never happens.
model.eval()
with torch.no_grad():
    score = model(test_imgs)

In [9]:
# Predicted class for each test image = column index of its highest score.
pred = score.max(dim=1)[1]
print(pred.shape)

torch.Size([10000])


In [10]:
# Test accuracy = fraction of images whose predicted class equals the label.
# One vectorised comparison replaces the per-element Python loop (which also used
# the Python-2-only `xrange`), and the set size comes from the data instead of a
# hard-coded 10000.
assert pred.shape == test_labels.shape, 'predictions and labels must align'
num_test = len(test_labels)
# .long() before summing avoids any small-integer overflow on the boolean/uint8 mask.
num_correct = (pred == test_labels).long().sum().item()
accuracy = float(num_correct) / num_test
print('acc: {}'.format(accuracy))

acc: 0.9285


In [11]:
# Persist only the learned weights (the state_dict); restore them by building
# MyLSTM(28) and calling load_state_dict() on it.
MODEL_PATH = 'mnist_lstm_model.pkl'
torch.save(model.state_dict(), MODEL_PATH)