In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

# Load MNIST Dataset

In [2]:
# Preprocessing pipeline applied to every image: convert to a tensor, then
# normalise with mean 0.5 and std 1.0 (i.e. subtract 0.5, no rescaling).
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,)),
])

# Mini-batch size shared by the training and test loaders.
batch_size = 128

# Only the training split asks for a download; the test split is read from the
# same root directory (presumably populated by the download above -- verify on
# a fresh checkout).
train_set = dset.MNIST(root='./data', train=True, transform=trans, download=True)
test_set = dset.MNIST(root='./data', train=False, transform=trans)

# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)


# Build MLP Model
![](images/MLP.png)

In [3]:
class MLP(nn.Module):
    """Multi-layer perceptron with one hidden layer: Linear -> ReLU -> Linear.

    Args:
        n_class: Number of output logits, one per class (default 10).
        in_features: Number of features per sample once the input is
            flattened (default 28*28).
        hidden_size: Width of the hidden layer (default 64).
    """

    def __init__(self, n_class=10, in_features=28*28, hidden_size=64):
        super(MLP, self).__init__()
        # Remembered so that forward() flattens to the width the first layer expects.
        self.in_features = in_features

        # A single Sequential named `fc`; its parameters are therefore exposed
        # as `fc.0.*` (hidden layer) and `fc.2.*` (output layer).
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, n_class),
        )

    def forward(self, x):
        """Flatten `x` to rows of `in_features` values and return class logits.

        Args:
            x: Tensor whose element count is a multiple of `in_features`,
               e.g. (batch_size, 1, 28, 28) with the default settings.

        Returns:
            Tensor of shape (batch_size, n_class) holding unnormalised scores;
            no softmax is applied here.
        """
        x = x.view(-1, self.in_features)   # (batch_size, ...) => (batch_size, in_features)
        logits = self.fc(x)
        return logits


# Training

In [4]:
model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train for 10 epochs, evaluating on the test set after each one.
# Written so that it runs on Python 3 (and still on Python 2): `range` instead
# of `xrange`, and `print(...)` called with a single pre-formatted string.
for epoch in range(10):
    # ---- training ----
    model.train()
    # Exponentially smoothed loss; it starts at 0, so the first few values
    # under-estimate the true loss.
    ave_loss = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, target)
        # `.item()` extracts the Python scalar; indexing a 0-dim tensor with
        # `loss.data[0]` is an error on PyTorch >= 0.5.
        ave_loss = ave_loss * 0.9 + loss.item() * 0.1
        loss.backward()
        optimizer.step()
        # Report every 100 batches and once more on the final batch of the epoch.
        if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == len(train_loader):
            print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                epoch, batch_idx + 1, ave_loss))

    # ---- testing ----
    model.eval()
    correct_cnt, ave_loss = 0, 0
    total_cnt = 0
    # `torch.no_grad()` replaces the removed `volatile=True` flag: no autograd
    # graph is recorded while evaluating.
    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(test_loader):
            logits = model(x)
            loss = criterion(logits, target)
            _, pred_label = torch.max(logits, 1)
            total_cnt += x.size(0)
            # `.item()` keeps the running count a plain Python int.
            correct_cnt += (pred_label == target).sum().item()
            # smooth average
            ave_loss = ave_loss * 0.9 + loss.item() * 0.1

            if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == len(test_loader):
                # `* 1.0` forces float division under Python 2 as well.
                print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                    epoch, batch_idx + 1, ave_loss, correct_cnt * 1.0 / total_cnt))

==>>> epoch: 0, batch index: 100, train loss: 0.599243
==>>> epoch: 0, batch index: 200, train loss: 0.370671
==>>> epoch: 0, batch index: 300, train loss: 0.356554
==>>> epoch: 0, batch index: 400, train loss: 0.363393
==>>> epoch: 0, batch index: 469, train loss: 0.328055
==>>> epoch: 0, batch index: 79, test loss: 0.258991, acc: 0.904
==>>> epoch: 1, batch index: 100, train loss: 0.330634
==>>> epoch: 1, batch index: 200, train loss: 0.321183
==>>> epoch: 1, batch index: 300, train loss: 0.289736
==>>> epoch: 1, batch index: 400, train loss: 0.268071
==>>> epoch: 1, batch index: 469, train loss: 0.252494
==>>> epoch: 1, batch index: 79, test loss: 0.196113, acc: 0.928
==>>> epoch: 2, batch index: 100, train loss: 0.245743
==>>> epoch: 2, batch index: 200, train loss: 0.261259
==>>> epoch: 2, batch index: 300, train loss: 0.216066
==>>> epoch: 2, batch index: 400, train loss: 0.228394
==>>> epoch: 2, batch index: 469, train loss: 0.218366
==>>> epoch: 2, batch index: 79, test loss: 0