In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [3]:
# 1 create model
class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel 28x28 images.

    Shape trace (channels, height, width):
        input                     (1, 28, 28)
        conv1 -> ReLU -> pool     (16, 14, 14)
        conv2 -> ReLU -> pool     (32, 7, 7)
        flatten -> linear         10 class scores (logits)
    """

    def __init__(self):
        super().__init__()
        # 5x5 kernels with stride 1 and padding 2 preserve height/width;
        # each 2x2 max-pool then halves them.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Fully connected head over the flattened 32*7*7 feature map.
        self.out = nn.Linear(in_features=32 * 7 * 7, out_features=10)

    def forward(self, x):
        """Map a batch of images to one row of 10 logits per sample."""
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)  # (batch_size, 32 * 7 * 7)
        return self.out(flat)

# Seed torch's global RNG before building the model so the random weight
# initialisation is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
net = CNN()
print(net)  # net architecture

CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)


In [4]:
# 2 load data
import torchvision
import torch.utils.data as Data

DATA_ROOT = './dataset/'  # local cache directory for MNIST
BATCH_SIZE = 32           # training mini-batch size

# download=True fetches MNIST only when it is not already under DATA_ROOT,
# so the notebook also runs on a fresh machine instead of failing here.
train_data = torchvision.datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True)

# No transform: only the raw `.data` / `.targets` tensors are used below.
test_data = torchvision.datasets.MNIST(
    root=DATA_ROOT,
    train=False,
    download=True)

# batch train data
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# preprocess test data
# size: (n, 28, 28) -> (n, 1, 28, 28)
# value [0, 255] -> [0, 1]
# The raw tensors carry no autograd history, so neither the deprecated
# `Variable` wrapper nor a `torch.no_grad()` block is needed here.
test_x = test_data.data.unsqueeze(dim=1).float() / 255.0
test_y = test_data.targets

In [6]:
# 3 train and evaluate model
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.02)  # plain SGD over all model weights
fun_loss = nn.CrossEntropyLoss()  # applied to the net's logits and integer class targets

def evaluate(x, y, model=None):
    '''
    Classify `x` with the model and measure accuracy against `y`.

    x: (n, 1, 28, 28) float images, prepared like test_x
    y: (n,) integer class labels, compared element-wise with the predictions
    model: module to evaluate; defaults to the global `net`

    Returns (accuracy, y_): accuracy is a Python float in [0, 1] and y_ is
    the (n,) tensor of predicted class indices.

    Raises ValueError if y is empty (accuracy would be 0 / 0).
    '''
    if model is None:
        model = net  # resolved at call time, so a re-created `net` is picked up
    n = y.size(0)
    if n == 0:
        raise ValueError('evaluate() needs at least one sample')

    was_training = model.training
    model.eval()  # inference behaviour for any dropout / batch-norm layers
    try:
        # Without no_grad, a forward pass over the whole test set builds and
        # holds an autograd graph for every sample on every call.
        with torch.no_grad():
            out = model(x)
    finally:
        model.train(was_training)  # restore the caller's mode

    y_ = out.argmax(dim=1)  # index of the highest logit in each row
    # Count on the tensor; builtin sum() would loop over n 0-d tensors in Python.
    accuracy = (y_ == y).sum().item() / n
    return accuracy, y_


# training and testing
EPOCHS = 3       # full passes over train_loader
LOG_EVERY = 50   # evaluate on the test set every LOG_EVERY steps

for epoch in range(EPOCHS):
    net.train()  # explicitly (re)enter training mode before the updates
    for step, (b_x, b_y) in enumerate(train_loader):
        # The DataLoader already yields plain tensors; the deprecated
        # `Variable` wrapper is a no-op and has been dropped.
        output = net(b_x)               # logits, one row per sample
        loss = fun_loss(output, b_y)    # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % LOG_EVERY == 0:
            # `loss` is this batch's loss only; accuracy covers the whole test set.
            accuracy, _ = evaluate(test_x, test_y)
            print(f'epoch: {epoch} | loss: {loss.item()} | accuracy: {accuracy}')

epoch: 0 | loss: 2.3022775650024414 | accuracy: 0.09589999914169312
epoch: 0 | loss: 2.147660255432129 | accuracy: 0.4860000014305115
epoch: 0 | loss: 0.9360531568527222 | accuracy: 0.6309000253677368
epoch: 0 | loss: 0.8220655918121338 | accuracy: 0.8097000122070312
epoch: 0 | loss: 0.20937883853912354 | accuracy: 0.8708999752998352
epoch: 0 | loss: 0.7453866004943848 | accuracy: 0.890500009059906
epoch: 0 | loss: 0.27988290786743164 | accuracy: 0.9071999788284302
epoch: 0 | loss: 0.7935819625854492 | accuracy: 0.9049000144004822
epoch: 0 | loss: 0.3070870339870453 | accuracy: 0.8985999822616577
epoch: 0 | loss: 0.5050113797187805 | accuracy: 0.9289000034332275
epoch: 0 | loss: 0.26430046558380127 | accuracy: 0.9395999908447266
epoch: 0 | loss: 0.2305402159690857 | accuracy: 0.9229999780654907
epoch: 0 | loss: 0.36997950077056885 | accuracy: 0.9391999840736389
epoch: 0 | loss: 0.24583810567855835 | accuracy: 0.9394000172615051
epoch: 0 | loss: 0.3160313069820404 | accuracy: 0.94739997