In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# download and transform train dataset
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data', 
                                                          download=True, 
                                                          train=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])), 
                                           batch_size=10, 
                                           shuffle=True)

# download and transform test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data', 
                                                          download=True, 
                                                          train=False,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])), 
                                           batch_size=10, 
                                           shuffle=True)

class CNNClassifier(nn.Module):
    """Custom module for a simple convnet classifier"""
    def __init__(self):
        super(CNNClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
    
    def forward(self, x):
        # input is 28x28x1
        # conv1(kernel=5, filters=10) 28x28x10 -> 24x24x10
        # max_pool(kernel=2) 24x24x10 -> 12x12x10
        
        # Do not be afraid of F's - those are just functional wrappers for modules form nn package
        # Please, see for yourself - http://pytorch.org/docs/_modules/torch/nn/functional.html
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        
        # conv2(kernel=5, filters=20) 12x12x20 -> 8x8x20
        # max_pool(kernel=2) 8x8x20 -> 4x4x20
        x = F.relu(F.max_pool2d(self.dropout(self.conv2(x)), 2))
        
        # flatten 4x4x20 = 320
        x = x.view(-1, 320)
        
        # 320 -> 50
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        
        # 50 -> 10
        x = self.fc2(x)
        
        # transform to logits
        return F.log_softmax(x)

# create classifier and optimizer objects
clf = CNNClassifier()
opt = optim.SGD(clf.parameters(), lr=0.01, momentum=0.5)

loss_history = []
acc_history = []

def train(epoch):
    clf.train() # set model in training mode (need this because of dropout)
    
    # dataset API gives us pythonic batching 
    for batch_id, (data, label) in enumerate(train_loader):
        data = Variable(data)
        target = Variable(label)
        
        # forward pass, calculate loss and backprop!
        opt.zero_grad()
        preds = clf(data)
        loss = F.nll_loss(preds, target)
        loss.backward()
        loss_history.append(loss.data[0])
        opt.step()
        
        if batch_id % 100 == 0:
            print(loss.data[0])

def test(epoch):
    clf.eval() # set model in inference mode (need this because of dropout)
    test_loss = 0
    correct = 0
    
    for data, target in test_loader:
        data = Variable(data, volatile=True) 
        target = Variable(target)
        
        output = clf(data)
        test_loss += F.nll_loss(output, target).data[0]
        pred = output.data.max(1)[1] # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()

    test_loss = test_loss
    test_loss /= len(test_loader) # loss function already averages over batch size
    accuracy = 100. * correct / len(test_loader.dataset)
    acc_history.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))

for epoch in range(0, 3):
    print("Epoch %d" % epoch)
    train(epoch)
    test(epoch)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
Epoch 0
tensor(2.3058)




tensor(2.0978)
tensor(1.8258)
tensor(1.6781)
tensor(0.7523)
tensor(2.0277)
tensor(0.6322)
tensor(0.5482)
tensor(0.6572)
tensor(0.3610)
tensor(0.3446)
tensor(0.5306)
tensor(0.4659)
tensor(0.2169)
tensor(0.9212)
tensor(0.6201)
tensor(0.4350)
tensor(0.2834)
tensor(0.8334)
tensor(0.5853)
tensor(0.2192)
tensor(0.2087)
tensor(0.6505)
tensor(0.7470)
tensor(0.3865)
tensor(0.1194)
tensor(0.2985)
tensor(0.8105)
tensor(0.3007)
tensor(0.9042)
tensor(0.1978)
tensor(0.1868)
tensor(0.2038)
tensor(0.3340)
tensor(0.2505)
tensor(0.8807)
tensor(0.3901)
tensor(0.0998)
tensor(0.3404)
tensor(0.1472)
tensor(0.2322)
tensor(0.0521)
tensor(0.1167)
tensor(0.0837)
tensor(0.0611)
tensor(0.4273)
tensor(0.0893)
tensor(0.1625)
tensor(0.0135)
tensor(0.0325)
tensor(0.3045)
tensor(0.1555)
tensor(0.2240)
tensor(0.3185)
tensor(0.7536)
tensor(0.5232)
tensor(0.1370)
tensor(0.1561)
tensor(2.0029)
tensor(0.0731)





Test set: Average loss: 0.0945, Accuracy: 9703/10000 (97%)

Epoch 1
tensor(0.3067)
tensor(1.2007)
tensor(0.5318)
tensor(0.3373)
tensor(1.0636)
tensor(0.3065)
tensor(0.3004)
tensor(1.8466)
tensor(0.1624)
tensor(0.8236)
tensor(0.0167)
tensor(0.1881)
tensor(0.0568)
tensor(0.2077)
tensor(0.1726)
tensor(1.6097)
tensor(0.4424)
tensor(0.5585)
tensor(0.2492)
tensor(1.1940)
tensor(0.5910)
tensor(0.0972)
tensor(0.2928)
tensor(0.0748)
tensor(0.0572)
tensor(0.8895)
tensor(0.2555)
tensor(0.1150)
tensor(0.0762)
tensor(0.2204)
tensor(0.2585)
tensor(0.1527)
tensor(0.6118)
tensor(0.3880)
tensor(0.0105)
tensor(0.7392)
tensor(0.0932)
tensor(0.3719)
tensor(0.2051)
tensor(0.1112)
tensor(0.4832)
tensor(0.2856)
tensor(0.4440)
tensor(0.0146)
tensor(0.4678)
tensor(0.0425)
tensor(0.1710)
tensor(0.0418)
tensor(0.1276)
tensor(2.1389)
tensor(0.1519)
tensor(0.1173)
tensor(0.1584)
tensor(0.1832)
tensor(0.2981)
tensor(0.0840)
tensor(0.8388)
tensor(0.0261)
tensor(0.1998)
tensor(0.9148)

Test set: Average loss: 0.0672