In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
trn_dataset = datasets.MNIST('../mnist_data/',
                            download=True,
                            train=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,),(0.3081,))
                            ]))

In [3]:
val_dataset = datasets.MNIST('../mnist_data/',
                            download=False,
                            train=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,),(0.3081,))
                            ]))

In [4]:
batch_size = 64
trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                        batch_size=batch_size,
                                        shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                        batch_size=batch_size,
                                        shuffle=True)

In [5]:
use_cuda = torch.cuda.is_available()

In [6]:
class CNNClassifier(nn.Module):
    def __init__(self):
        super(CNNClassifier, self).__init__()
        conv1 = nn.Conv2d(1,6,5,1)  #6@24*24
        # activation relu
        pool1 = nn.MaxPool2d(2)
        conv2 = nn.Conv2d(6,16,5,1)  #16@8*8
        # activation relu
        pool2 = nn.MaxPool2d(2)  #16@4*4
        
        self.conv_module = nn.Sequential(
                    conv1,
                    nn.ReLU(),
                    pool1,
                    conv2,
                    nn.ReLU(),
                    pool2)
        fc1 = nn.Linear(16*4*4, 120)
        # activation relu
        fc2 = nn.Linear(120, 84)
        # activation relu
        fc3 = nn.Linear(84, 10)
        
        self.fc_module = nn.Sequential(
                            fc1,
                            nn.ReLU(),
                            fc2,
                            nn.ReLU(),
                            fc3)
        # gpu할당
        if use_cuda:
            self.conv_module = self.conv_module.cuda()
            self.fc_module = self.fc_module.cuda()
            
    def forward(self, x):
        out = self.conv_module(x)
        # make linear
        dim = 1
        for d in out.size()[1:]:
            dim = dim*d
        out = out.view(-1, dim)
        out = self.fc_module(out)
        return F.softmax(out, dim=1)

In [7]:
# model
cnn = CNNClassifier()
# loss
criterion = nn.CrossEntropyLoss()
# backpropagation
learning_rate = 1e-3
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)
# hypter-parameters
num_epochs = 2
num_batches = len(trn_loader)

In [None]:
trn_loss_list, val_loss_list = [], []
for epoch in range(num_epochs):
    trn_loss = 0
    for i, data in enumerate(trn_loader):
        x, label = data
        if use_cuda:
            x, label = x.cuda(), label.cuda()
        # grad_init
        optimizer.zero_grad()
        # forward propagation
        model_output = cnn(x)
        # calculate_loss
        loss = criterion(model_output, label)
        # back propagation
        loss.backward()
        # weight update
        optimizer.step()
        
        # trn_loss summary
        trn_loss += loss.item()
        # del(memory issue)
        del loss
        del model_output
        
        # 학습과정 출력
        if (i+1) % 100 == 0:
            with torch.no_grad():
                val_loss = 0
                for j, val in enumerate(val_loader):
                    val_x, val_label = val
                    if use_cuda:
                        val_x, val_label = val_x.cuda(), val_label.cuda()
                    val_output = cnn(val_x)
                    v_loss = criterion(val_output, val_label)
                    val_loss += v_loss
                    
            print("epoch: {}/{} | step: {}/{} | trn loss: {:.4f} | val loss: {:.4f}".format(
                    epoch+1, num_epochs, i+1, num_batches, trn_loss / 100, val_loss / len(val_loader)))
            
            trn_loss_list.append(trn_loss/100)
            val_loss_list.append(val_loss/len(val_loader))
            trn_loss = 0

epoch: 1/2 | step: 100/938 | trn loss: 1.6237 | val loss: 1.5429
epoch: 1/2 | step: 200/938 | trn loss: 1.5407 | val loss: 1.5239
epoch: 1/2 | step: 300/938 | trn loss: 1.5240 | val loss: 1.5115
epoch: 1/2 | step: 400/938 | trn loss: 1.5151 | val loss: 1.5078
epoch: 1/2 | step: 500/938 | trn loss: 1.5086 | val loss: 1.5069
epoch: 1/2 | step: 600/938 | trn loss: 1.5079 | val loss: 1.4973
epoch: 1/2 | step: 700/938 | trn loss: 1.5008 | val loss: 1.4983
epoch: 1/2 | step: 800/938 | trn loss: 1.5012 | val loss: 1.4984
epoch: 1/2 | step: 900/938 | trn loss: 1.5004 | val loss: 1.4933
epoch: 2/2 | step: 100/938 | trn loss: 1.4946 | val loss: 1.4877
epoch: 2/2 | step: 200/938 | trn loss: 1.4883 | val loss: 1.4919
epoch: 2/2 | step: 300/938 | trn loss: 1.4944 | val loss: 1.4865
epoch: 2/2 | step: 400/938 | trn loss: 1.4914 | val loss: 1.4966
epoch: 2/2 | step: 500/938 | trn loss: 1.4914 | val loss: 1.4891
epoch: 2/2 | step: 600/938 | trn loss: 1.4901 | val loss: 1.4814
epoch: 2/2 | step: 700/93