In [2]:
import torchvision as tv
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch as t
from torch import nn
from torch.nn import functional as F

mnist_ds_path = '../datasets'
# you can change the batch_size, epochs, batch_limited to reach different accuracy
batch_size = 32
epochs = 35
batch_limited = 1000
alpha = 0.001
classes_num = 10


train_ds = tv.datasets.MNIST(root=mnist_ds_path, train=True, transform=transforms.ToTensor())
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)  #60000

test_ds = tv.datasets.MNIST(root=mnist_ds_path, train=False, transform=transforms.ToTensor())
test_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=True)    #10000

print('we have %d train samples, and %d test samples' % (len(train_dataloader), len(test_dataloader)))

we have 1875 train samples, and 313 test samples


# This is an example for lenet(a cnn typical network structure) which is all constructed by torch

In [3]:
class LeNet(nn.Module):
    def __init__(self, classes_num):
        super(LeNet, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 8, kernel_size=(3,3), stride=1, padding=1)
        self.max_pool1 = nn.MaxPool2d((2,2))
        self.conv2 = nn.Conv2d(8, 16, kernel_size=(3,3), stride=1, padding=1)
        self.max_pool2 = nn.MaxPool2d((2,2))
        self.linear1 = nn.Linear(16*7*7, 512)
        self.linear2 = nn.Linear(512, 256)
        self.classifier = nn.Linear(256, classes_num)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.max_pool1(x)
        x = F.relu(self.conv2(x))
        x = self.max_pool2(x)
        
        x = x.view(x.shape[0], -1)
        
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        pred = self.classifier(x)
        
        return pred

In [4]:
def calc_accu(pred, gt):
    final = pred.argmax(dim=1)
    return (final == gt).float().mean()

In [5]:
def test(model, criterion):
    model.eval()
    
    loss = 0
    accuracy = 0
    round_num = 0
    for (x,y) in test_dataloader:
        pred = model(x)
        
        cur_loss = criterion(pred, y)
        cur_accuracy = calc_accu(pred, y)
        
        loss += cur_loss.item()
        accuracy += cur_accuracy.item()
        
        round_num += 1
        
    model.train()
    
    return loss/round_num, accuracy/round_num

In [6]:
lenet = LeNet(classes_num)
# loss function
criterion = nn.CrossEntropyLoss()
# no momentum means poor performance(may drop 3 percent), momentum=0.9 is good enough
# the best test precision is 0.9801
sgd = t.optim.SGD(lenet.parameters(), lr=alpha, momentum=0.9)

In [7]:
for epoch in range(epochs):
    loss = 0
    accuracy = 0
    round_num = 0
    for idx, (x,y) in enumerate(train_dataloader):
        pred = lenet(x)
        
        sgd.zero_grad()
        
        cur_loss = criterion(pred, y)
        cur_accuracy = calc_accu(pred, y)
        
        cur_loss.backward()
        sgd.step()
        
        loss += cur_loss.item()
        accuracy += cur_accuracy.item()
        
        round_num += 1
        
        if round_num == batch_limited:
            break
    
    # do test
    test_loss, test_acc = test(lenet, criterion)
    print('in epoch %d, train loss: %.4f, train acc: %.4f, test loss: %.4f, test acc: %.4f' % \
          (epoch, loss/round_num, accuracy/round_num, test_loss, test_acc))
    
        

in epoch 0, train loss: 1.6929, train acc: 0.4716, test loss: 0.5493, test acc: 0.8175
in epoch 1, train loss: 0.3602, train acc: 0.8884, test loss: 0.2802, test acc: 0.9165
in epoch 2, train loss: 0.2468, train acc: 0.9218, test loss: 0.1936, test acc: 0.9392
in epoch 3, train loss: 0.1836, train acc: 0.9443, test loss: 0.1557, test acc: 0.9515
in epoch 4, train loss: 0.1479, train acc: 0.9545, test loss: 0.1568, test acc: 0.9481
in epoch 5, train loss: 0.1206, train acc: 0.9614, test loss: 0.0995, test acc: 0.9686
in epoch 6, train loss: 0.1048, train acc: 0.9667, test loss: 0.0994, test acc: 0.9669
in epoch 7, train loss: 0.1002, train acc: 0.9684, test loss: 0.0927, test acc: 0.9705
in epoch 8, train loss: 0.0853, train acc: 0.9738, test loss: 0.1045, test acc: 0.9675
in epoch 9, train loss: 0.0784, train acc: 0.9742, test loss: 0.0700, test acc: 0.9768
in epoch 10, train loss: 0.0756, train acc: 0.9768, test loss: 0.0648, test acc: 0.9779
in epoch 11, train loss: 0.0676, train acc