## 중요 모듈 import

In [1]:
import os
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

## 데이터 셋

In [2]:
root = './data'
if not os.path.exists(root):
    os.mkdir(root)

In [3]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = dset.MNIST(root=root, train=True, transform=trans, download=True)
test_set = dset.MNIST(root=root, train=False, transform=trans, download=True)

## Hyper Parameters

In [4]:
# Hyper Parameters
batch_size = 100
total_epoch = 10
learning_rate = 0.1
use_cuda = torch.cuda.is_available()

## Data Loader

In [5]:
# Data Loader
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)
print('==>> total trainning batch number: {}'.format(len(train_loader)))
print('==>> total testing batch number: {}'.format(len(train_loader)))

==>> total trainning batch number: 600
==>> total testing batch number: 600


## MNIST MLP Model

In [6]:
# MNIST MLP Model

class MLPNet(nn.Module):
    def __init__(self):
        super(MLPNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))

        x = self.fc4(x)
        
        return F.log_softmax(x, dim=1)
    
    def name(self):
        return "MLP"

In [7]:
model = MLPNet()
if use_cuda:
    model = model.cuda()

In [8]:
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss() # 이미지 분류 모델은 일반적으로 CrossEntropy Loss Function을 사용

## 모델 학습

In [9]:
for epoch in range(total_epoch):
    # trainning
    total_loss = 0
    total_batch = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        if use_cuda:
            x, target = x.cuda(), target.cuda()

        out = model(x)
        loss = criterion(out, target)
        total_loss += loss.item()
        total_batch += 1
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == len(train_loader):
            print('==>>> epoch : {}, batch index : {}, train loss : {:.6f}'
                  .format(epoch, batch_idx+1, total_loss/total_batch))
    # testing
    total_loss = 0
    total_batch = 0
    correct_cnt = 0
    total_cnt = 0
    for batch_idx, (x, target) in enumerate(test_loader):
        if use_cuda:
            x, target = x.cuda(), target.cuda()

        out = model(x)
        loss = criterion(out, target)
        _, pred_label = torch.max(out.data, 1)
        total_cnt += x.data.size()[0]
        correct_cnt += (pred_label == target.data).sum().item()

        total_loss += loss.item()
        total_batch += 1

        if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == len(test_loader):
            print('==>>> epoch : {}, batch index : {}, train loss : {:.6f}, acc: {:.3f}'
                  .format(epoch, batch_idx + 1, total_loss / total_batch, correct_cnt * 1.0 / total_cnt))

==>>> epoch : 0, batch index : 100, train loss : 2.311269
==>>> epoch : 0, batch index : 200, train loss : 2.310656
==>>> epoch : 0, batch index : 300, train loss : 2.310002
==>>> epoch : 0, batch index : 400, train loss : 2.309495
==>>> epoch : 0, batch index : 500, train loss : 2.308900
==>>> epoch : 0, batch index : 600, train loss : 2.308591
==>>> epoch : 0, batch index : 100, train loss : 2.307865, acc: 0.089
==>>> epoch : 1, batch index : 100, train loss : 2.306478
==>>> epoch : 1, batch index : 200, train loss : 2.306491
==>>> epoch : 1, batch index : 300, train loss : 2.306355
==>>> epoch : 1, batch index : 400, train loss : 2.305756
==>>> epoch : 1, batch index : 500, train loss : 2.305622
==>>> epoch : 1, batch index : 600, train loss : 2.305267
==>>> epoch : 1, batch index : 100, train loss : 2.303234, acc: 0.101
==>>> epoch : 2, batch index : 100, train loss : 2.304673
==>>> epoch : 2, batch index : 200, train loss : 2.304257
==>>> epoch : 2, batch index : 300, train loss :