In [18]:
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

In [19]:
#Training Set
train_dataset=datasets.MNIST(root='./',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
#Testing Set
test_dataset=datasets.MNIST(root='./',
                            train=False,
                            transform=transforms.ToTensor(),
                            download=True)


In [59]:
#batch size
batch_size=64
#training set loader
train_loader=DataLoader(dataset=train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
#Testing set loader
test_loader=DataLoader(dataset=test_dataset,
                       batch_size=batch_size,
                       shuffle=True)
#print dataset
for i,data in enumerate(train_loader):
    inputs,labels=data
    print(inputs.shape)
    print(labels.shape)
    break


torch.Size([64, 1, 28, 28])
torch.Size([64])


In [56]:
#define network algorithm
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = torch.nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=1,
            batch_first=True
        )
        self.out = torch.nn.Linear(in_features=64,out_features=10)
        self.softmax = torch.nn.Softmax(dim=1)
    def forward(self,x):
        #Value(batch,seq_len,feature)
        x=x.view(-1,28,28)
        #h_n:[num_layers, batch, hidden_size]
        #c_n:[num_layers, batch, hidden_size]
        #output:[batch,seq_len,hidden_size]
        output,(h_n,c_n)=self.lstm(x)
        output_in_last_timestep=h_n[-1,:,:]
        x=self.out(output_in_last_timestep)
        x=self.softmax(x)
        return x
# define module
lr=0.001
model=LSTM()
#define cross entropy loss function
cross_entropy_loss=nn.CrossEntropyLoss()
#define optimizer with learning rate
optimizer=optim.Adam(model.parameters(),lr)

In [65]:
#define module training
def train():
    # module training mode (Dropout using)
    model.train()
    for i, data in enumerate(train_loader):
        #obtain a batch of inputs and labels
        inputs, labels=data
        #obtain the module predict result (64,10)
        out =model(inputs)
        #cross entropy function: out(batch,C), labels(batch)
        loss=cross_entropy_loss(out, labels)
        #initial gradient value
        optimizer.zero_grad()
        #calculate gradient 
        loss.backward()
        #renew weight
        optimizer.step()
        
#define module testing
def test():
    #module testing mode (Dropout stopping)
    model.eval()
    #Testing accuracy monitor
    correct=0
    for i,data in enumerate(test_loader):
        #obtain a batch of inputs and labels
        inputs,labels=data
        #obtain the module predict result (64,10)
        out=model(inputs)
        #obtain max value and it position
        _, predicted=torch.max(out,1)
        # predict correct amount
        correct += (predicted==labels).sum()
        print("Test accuracy:{0}".format(correct.item() / len(test_dataset)))
    #Training accuracy monitor
    correct=0
    for i,data in enumerate(train_loader):
        #obtain a batch of inputs and labels
        inputs,labels=data
        #obtain the module predict result (64,10)
        out=model(inputs)
        #obtain max value and it position
        _, predicted=torch.max(out,1)
        # predict correct amount
        correct += (predicted==labels).sum()
        print("Train accuracy:{0}".format(correct.item()/len(train_dataset)))

In [66]:
for epoch in range(10):
    print('epoch:',epoch)
    train()
    test()

epoch: 0
Test acc: 0.0063
Test acc: 0.0126
Test acc: 0.0189
Test acc: 0.0252
Test acc: 0.0316
Test acc: 0.038
Test acc: 0.0444
Test acc: 0.0508
Test acc: 0.0571
Test acc: 0.0634
Test acc: 0.0696
Test acc: 0.0759
Test acc: 0.0822
Test acc: 0.0886
Test acc: 0.0948
Test acc: 0.101
Test acc: 0.1074
Test acc: 0.1137
Test acc: 0.1198
Test acc: 0.1261
Test acc: 0.1322
Test acc: 0.1386
Test acc: 0.1449
Test acc: 0.1511
Test acc: 0.1575
Test acc: 0.1636
Test acc: 0.1699
Test acc: 0.1763
Test acc: 0.1827
Test acc: 0.189
Test acc: 0.1952
Test acc: 0.2014
Test acc: 0.2076
Test acc: 0.2138
Test acc: 0.2202
Test acc: 0.2263
Test acc: 0.2323
Test acc: 0.2387
Test acc: 0.2451
Test acc: 0.2515
Test acc: 0.2578
Test acc: 0.2638
Test acc: 0.2699
Test acc: 0.2761
Test acc: 0.2823
Test acc: 0.2886
Test acc: 0.295
Test acc: 0.301
Test acc: 0.3072
Test acc: 0.3133
Test acc: 0.3195
Test acc: 0.3256
Test acc: 0.332
Test acc: 0.3383
Test acc: 0.3446
Test acc: 0.3507
Test acc: 0.3569
Test acc: 0.3632
Test acc: 0