In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

# Load MNIST Dataset

In [2]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = dset.MNIST(root='./data', train=True, transform=trans, download=True)
test_set = dset.MNIST(root='./data', train=False, transform=trans)

batch_size = 128

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)


# Build CNN Model
![](images/CNN.png)

In [3]:
class LeNet(nn.Module):
    def __init__(self,n_class=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels = 1,
            out_channels = 20,
            kernel_size = 5
        )
        self.conv2 = nn.Conv2d(
            in_channels = 20,
            out_channels = 50,
            kernel_size = 5
        ) 
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, n_class)
    def forward(self, x):
        x = F.relu(self.conv1(x))   # x:[batch_size,1,28,28] => x:[batch_size,20, 24, 24]
        x = F.max_pool2d(x, 2, 2)   # x:[batch_size,20,24,24] => x:[batch_size,20, 12, 12]
        x = F.relu(self.conv2(x))   # x:[batch_size,20,12,12] => x:[batch_size,50, 8, 8]
        x = F.max_pool2d(x, 2, 2)   # x:[batch_size,50,8,8] => x:[batch_size,50, 4, 4]
        x = x.view(-1, 4*4*50)      # x:[batch_size,50,4,4] => x:[batch_size,50*4*4]
        x = F.relu(self.fc1(x))     # x:[batch_size,50*4*4] => x:[batch_size,500]
        x = self.fc2(x)             # x:[batch_size,500] => x:[batch_size,10]
        return x


# Training

In [4]:
model = LeNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in xrange(10):
    # trainning
    ave_loss = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        x, target = Variable(x), Variable(target)
        logits = model(x)
        exit()
        loss = criterion(logits, target)
        ave_loss = ave_loss * 0.9 + loss.data[0] * 0.1
        loss.backward()
        optimizer.step()
        if (batch_idx+1) % 100 == 0 or (batch_idx+1) == len(train_loader):
            print '==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                epoch, batch_idx+1, ave_loss)
    # testing
    correct_cnt, ave_loss = 0, 0
    total_cnt = 0
    for batch_idx, (x, target) in enumerate(test_loader):
        x, target = Variable(x, volatile=True), Variable(target, volatile=True)
        logits = model(x)
        loss = criterion(logits, target)
        _, pred_label = torch.max(logits.data, 1)
        total_cnt += x.data.size()[0]
        correct_cnt += (pred_label == target.data).sum()
        # smooth average
        ave_loss = ave_loss * 0.9 + loss.data[0] * 0.1
        
        if(batch_idx+1) % 100 == 0 or (batch_idx+1) == len(test_loader):
            print '==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                epoch, batch_idx+1, ave_loss, correct_cnt * 1.0 / total_cnt)

==>>> epoch: 0, batch index: 100, train loss: 0.203049
==>>> epoch: 0, batch index: 200, train loss: 0.092486
==>>> epoch: 0, batch index: 300, train loss: 0.076701
==>>> epoch: 0, batch index: 400, train loss: 0.078407
==>>> epoch: 0, batch index: 469, train loss: 0.062016
==>>> epoch: 0, batch index: 79, test loss: 0.028026, acc: 0.986
==>>> epoch: 1, batch index: 100, train loss: 0.043321
==>>> epoch: 1, batch index: 200, train loss: 0.057436
==>>> epoch: 1, batch index: 300, train loss: 0.051468
==>>> epoch: 1, batch index: 400, train loss: 0.047724
==>>> epoch: 1, batch index: 469, train loss: 0.034381
==>>> epoch: 1, batch index: 79, test loss: 0.024080, acc: 0.989
==>>> epoch: 2, batch index: 100, train loss: 0.035867
==>>> epoch: 2, batch index: 200, train loss: 0.035561
==>>> epoch: 2, batch index: 300, train loss: 0.022927
==>>> epoch: 2, batch index: 400, train loss: 0.037693
==>>> epoch: 2, batch index: 469, train loss: 0.046425
==>>> epoch: 2, batch index: 79, test loss: 0