In [6]:
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

In [7]:
#Training Set
train_dataset=datasets.MNIST(root='./',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
#Testing Set
test_dataset=datasets.MNIST(root='./',
                            train=False,
                            transform=transforms.ToTensor(),
                            download=True)


In [11]:
#batch size
batch_size=64
#training set loader
train_loader=DataLoader(dataset=train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
#Testing set loader
test_loader=DataLoader(dataset=test_dataset,
                       batch_size=batch_size,
                       shuffle=True)
#print dataset
for i,data in enumerate(train_loader):
    inputs,labels=data
    print(inputs.shape)
    print(labels.shape)
    break
    
print(labels)
print(len(train_loader))
print(len(test_loader))

torch.Size([64, 1, 28, 28])
torch.Size([64])
tensor([8, 8, 4, 1, 0, 8, 8, 3, 1, 6, 7, 9, 6, 5, 2, 0, 5, 9, 5, 7, 1, 3, 1, 2,
        7, 7, 5, 2, 2, 3, 9, 1, 5, 7, 1, 2, 2, 2, 0, 2, 6, 4, 5, 7, 5, 4, 6, 0,
        4, 8, 0, 1, 1, 5, 4, 9, 5, 7, 1, 0, 0, 4, 0, 6])
938
157


In [14]:
#define network algorithm
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.convolution1=nn.Sequential(nn.Conv2d(1,64,5,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
        self.convolution2=nn.Sequential(nn.Conv2d(64,128,5,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
        self.fullcon1=nn.Sequential(nn.Linear(128*7*7,2048),nn.Dropout(p=0.5),nn.ReLU())
        self.fullcon2=nn.Sequential(nn.Linear(2048,10),nn.Softmax(dim=1))
   
    def forward(self,x):
        #MNIST Size [64,1,28,28]
        x=self.convolution1(x)
        x=self.convolution2(x)
        #reshape dimension (64,128,7,7)
        x=x.view(x.size()[0],-1)
        x=self.fullcon1(x)
        x=self.fullcon2(x)
        return x
# define module
model=CNN()
#define cross entropy loss function
cross_entropy_loss=nn.CrossEntropyLoss()
#define optimizer with learning rate
optimizer=optim.Adam(model.parameters(),lr=0.001)

In [15]:
#define module training
def train():
    # module training mode (Dropout using)
    model.train()
    for i, data in enumerate(train_loader):
        #obtain a batch of inputs and labels
        inputs, labels=data
        #obtain the module predict result (64,10)
        out =model(inputs)
        #cross entropy function: out(batch,C), labels(batch)
        loss=cross_entropy_loss(out, labels)
        #initial gradient value
        optimizer.zero_grad()
        #calculate gradient 
        loss.backward()
        #renew weight
        optimizer.step()
#define module testing
def test():
    #module testing mode (Dropout stopping)
    model.eval()
    #Testing accuracy monitor
    correct=0
    for i,data in enumerate(test_loader):
        #obtain a batch of inputs and labels
        inputs,labels=data
        #obtain the module predict result (64,10)
        out=model(inputs)
        #obtain max value and it position
        _, predicted=torch.max(out,1)
        # predict correct amount
        correct += (predicted==labels).sum()
        print("Test accuracy:{0}".format(correct.item()/len(test_dataset)))
    #Training accuracy monitor
    correct=0
    for i,data in enumerate(train_loader):
        #obtain a batch of inputs and labels
        inputs,labels=data
        #obtain the module predict result (64,10)
        out=model(inputs)
        #obtain max value and it position
        _, predicted=torch.max(out,1)
        # predict correct amount
        correct += (predicted==labels).sum()
        print("Train accuracy:{0}".format(correct.item()/len(train_dataset)))

In [16]:
for epoch in range(20):
    print('epoch:',epoch)
    train()
    test()

epoch: 0
Test accuracy:0.0064
Test accuracy:0.0126
Test accuracy:0.0188
Test accuracy:0.0251
Test accuracy:0.0314
Test accuracy:0.0377
Test accuracy:0.0439
Test accuracy:0.0502
Test accuracy:0.0562
Test accuracy:0.0625
Test accuracy:0.0686
Test accuracy:0.0748
Test accuracy:0.0812
Test accuracy:0.0872
Test accuracy:0.0936
Test accuracy:0.0999
Test accuracy:0.1062
Test accuracy:0.1126
Test accuracy:0.119
Test accuracy:0.1254
Test accuracy:0.1317
Test accuracy:0.1378
Test accuracy:0.1442
Test accuracy:0.1505
Test accuracy:0.1567
Test accuracy:0.1629
Test accuracy:0.1691
Test accuracy:0.1754
Test accuracy:0.1816
Test accuracy:0.1878
Test accuracy:0.194
Test accuracy:0.2003
Test accuracy:0.2065
Test accuracy:0.2128
Test accuracy:0.2191
Test accuracy:0.2251
Test accuracy:0.2314
Test accuracy:0.2377
Test accuracy:0.244
Test accuracy:0.2502
Test accuracy:0.2565
Test accuracy:0.2628
Test accuracy:0.269
Test accuracy:0.2752
Test accuracy:0.2813
Test accuracy:0.2876
Test accuracy:0.294
Test accu