In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
import os


In [2]:
num_epochs = 5
batch_size = 100
learning_rate = 0.001

train_dataset  = dsets.MNIST(root='./data/',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data/',
                          train=False,
                          transform=transforms.ToTensor())

train_loader  = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                         batch_size=batch_size,
                                         shuffle=False)

In [3]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2), #1 * 28 * 28 -> 16 * 28 * 28
            nn.BatchNorm2d(16), #16 * 28 * 28
            nn.ReLU(),                 #16 * 28 * 28
            nn.MaxPool2d(2))       # 16 * 14 * 14
        self.layer2  = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2), # 16 * 14 * 14 -> 32 * 14 * 14
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))       # 32 * 7 * 7
        self.fc = nn.Linear(32 * 7 * 7, 10)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
    
cnn = CNN()

In [4]:
if os.path.isfile('pkl/cnn.pkl'):
    cnn.load_state_dict(torch.load('pkl/cnn.pkl'))

else:
    criterion = nn.CrossEntropyLoss() # Loss
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
    
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = Variable(images)
            labels = Variable(labels)
            
            optimizer.zero_grad()
            outputs = cnn(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            if (i+1) % 100 == 0:
                print('Epoch [%d/%d], lter [%d/%d] Loss: %.4f'
                     %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.data))
                if not os.path.isfile('pkl/cnn,pkl'):
                    torch.save(cnn.state_dict(), 'pkl/cnn.pkl')


Epoch [1/5], lter [100/600] Loss: 0.1962
Epoch [1/5], lter [200/600] Loss: 0.1328
Epoch [1/5], lter [300/600] Loss: 0.0760
Epoch [1/5], lter [400/600] Loss: 0.1075
Epoch [1/5], lter [500/600] Loss: 0.2402
Epoch [1/5], lter [600/600] Loss: 0.0304
Epoch [2/5], lter [100/600] Loss: 0.0564
Epoch [2/5], lter [200/600] Loss: 0.0171
Epoch [2/5], lter [300/600] Loss: 0.0214
Epoch [2/5], lter [400/600] Loss: 0.0145
Epoch [2/5], lter [500/600] Loss: 0.0393
Epoch [2/5], lter [600/600] Loss: 0.0162
Epoch [3/5], lter [100/600] Loss: 0.0066
Epoch [3/5], lter [200/600] Loss: 0.0461
Epoch [3/5], lter [300/600] Loss: 0.0303
Epoch [3/5], lter [400/600] Loss: 0.0459
Epoch [3/5], lter [500/600] Loss: 0.0069
Epoch [3/5], lter [600/600] Loss: 0.0063
Epoch [4/5], lter [100/600] Loss: 0.0094
Epoch [4/5], lter [200/600] Loss: 0.0027
Epoch [4/5], lter [300/600] Loss: 0.0245
Epoch [4/5], lter [400/600] Loss: 0.0058
Epoch [4/5], lter [500/600] Loss: 0.0344
Epoch [4/5], lter [600/600] Loss: 0.0779
Epoch [5/5], lte

In [5]:
cnn.eval()
correct = 0
total = 0
for images, labels in test_loader:
    images = Variable(images)
    outputs = cnn(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
    
print('test Accuracy 10000 test images  = %f %%' % (100 * correct/total))

test Accuracy 10000 test images  = 98.000000 %
