# import torch

In [None]:
import mnist_train as mt

In [None]:
import time
import torch
from torch import optim
torch.__version__

# GPU check

In [None]:
# Pick the compute device: CUDA when a usable GPU is present, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

# Download MNIST dataset

## What is MNIST dataset?
> MNIST 데이터베이스 (Modified National Institute of Standards and Technology database)는  
손으로 쓴 숫자들로 이루어진 대형 데이터베이스이며,   
다양한 화상 처리 시스템을 트레이닝하기 위해 일반적으로 사용된다.   
이 데이터베이스는 또한 기계 학습 분야의 트레이닝 및 테스트에 널리 사용된다.  
>  
> https://ko.wikipedia.org/wiki/MNIST_데이터베이스

In [None]:
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np
import matplotlib.pyplot as plt

# MNIST training split stored under 'data'; download=True asks torchvision to
# fetch the files when needed. Samples are converted with ToTensor().
train_data = datasets.MNIST(root="data", train=True, transform=ToTensor(), download=True)

# MNIST test split read from the same root directory (no download requested here).
test_data = datasets.MNIST(root="data", train=False, transform=ToTensor())


In [None]:
# Characteristics of the training data: dataset summary and size of the raw data tensor
print(train_data)
print(train_data.data.size())

In [None]:
# Inspect what a single sample contains: the first (image, label) pair
img, label = train_data[0]
print(img.shape)
print(label)

In [None]:
mt.plot_train_data(img, label)

# Plot multiple train_data

In [None]:
mt.plot_multiple_train_data(train_data, rows=5, cols=8)

# Dataloader

설정된 `batch_size` 단위로 데이터를 가져올 수 있도록 하는 모듈

In [None]:
from torch.utils.data import DataLoader

# Mini-batch iterators keyed by phase name ('train' / 'test'), 100 samples per batch.
loaders = {
    # Training samples are reshuffled so each epoch sees the data in a new order.
    'train': DataLoader(train_data, batch_size=100, shuffle=True),
    # Evaluation gains nothing from shuffling; a fixed order keeps runs reproducible.
    'test': DataLoader(test_data, batch_size=100, shuffle=False),
}
loaders

# Define Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    """Small convolutional classifier for 1x28x28 images with 10 output classes.

    ``forward`` returns raw class scores (logits), which is the input that
    ``nn.CrossEntropyLoss`` expects; that loss applies log-softmax itself, so
    applying softmax inside the model would normalise twice and weaken the
    gradients. Use ``predict_proba`` when normalised probabilities are needed.
    The arg-max of the logits and of the probabilities is the same class.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(                        # 28*28 -> 24*24 (5x5 kernel, no padding)
                in_channels=1,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=0,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),      # 24*24 -> 12*12
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 0),       # 12*12 -> 8*8
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 8*8 -> 4*4
        )
        # Fully connected layer mapping the 64 * 4 * 4 features to 10 classes.
        self.out = nn.Linear(64 * 4 * 4, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch_size, 10)."""
        x = self.conv1(x)
        x = self.conv2(x)

        # Flatten the output of conv2 to (batch_size, 64 * 4 * 4).
        x = x.view(x.size(0), -1)
        return self.out(x)

    def predict_proba(self, x):
        """Return class probabilities (softmax over the logits), shape (batch_size, 10)."""
        return F.softmax(self(x), dim=1)

In [None]:
model = CNN()
model = model.to(device)

In [None]:
#!pip install torchsummary
from torchsummary import summary
summary(model, input_size=(1,28,28))

In [None]:
mt.show_sample_predict_cnn(model, device, test_data)

# Train model

In [None]:
from torch.autograd import Variable

def train(model, loaders, num_epochs, loss_func, optimizer, train_loss_list:list, test_loss_list:list):
    """Train ``model`` for ``num_epochs`` epochs, measuring the test loss after each one.

    Args:
        model: Network to optimise. Batches are moved to the device its
            parameters live on (CPU when the model has no parameters).
        loaders: Mapping with 'train' and 'test' entries, each a sized iterable
            of ``(images, labels)`` batches.
        num_epochs: Number of passes over the training loader.
        loss_func: Callable ``loss_func(output, labels)`` returning a scalar tensor.
        optimizer: Optimiser that updates the model's parameters.
        train_loss_list: List extended in place with each epoch's mean training loss.
        test_loss_list: List extended in place with each epoch's mean test loss.

    Returns:
        The tuple ``(model, train_loss_list, test_loss_list)``; the lists are the
        same objects that were passed in.
    """
    # Follow the model's own device so data and weights always end up together.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    total_step = len(loaders['train'])

    for epoch in range(num_epochs):
        loss_dict = {
            'train': 0.,
            'test': 0.
        }
        start_time = time.time()

        for phase in ['train', 'test']:
            is_train = phase == 'train'
            # train(True) enables training mode, train(False) is equivalent to eval().
            model.train(is_train)

            # Gradients are only tracked while training; the test pass skips the
            # autograd graph entirely, which saves memory and time.
            with torch.set_grad_enabled(is_train):
                for i, (images, labels) in enumerate(loaders[phase]):
                    images = images.to(run_device, dtype=torch.float32)
                    labels = labels.to(run_device)

                    output = model(images)
                    loss = loss_func(output, labels)
                    loss_dict[phase] += loss.item()

                    if is_train:
                        optimizer.zero_grad()     # clear gradients for this training step
                        loss.backward()           # backpropagation, compute gradients
                        optimizer.step()          # apply gradients

                        # Log every 100th batch.
                        if (i + 1) % 100 == 0:
                            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.6f}'
                                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

        # Mean per-batch loss of each phase for this epoch.
        train_loss = loss_dict['train'] / len(loaders['train'])
        test_loss = loss_dict['test'] / len(loaders['test'])

        train_loss_list.append(train_loss)
        test_loss_list.append(test_loss)
        duration = time.time() - start_time
        print(f"Epoch [{epoch+1}/{num_epochs}] summary, train_loss:{train_loss:8.8f}, "
              f"test_loss:{test_loss:8.8f} duration: {duration:.1f}s")

    return model, train_loss_list, test_loss_list


In [None]:
# Fresh network placed on the selected device, so training starts from new weights.
model = CNN().to(device)

# Classification loss plus an Adam optimiser (learning rate 1e-4) over all parameters.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Per-epoch loss histories that train() appends to.
train_loss_list, test_loss_list = [], []

In [None]:
model, train_loss_list, test_loss_list = train(model, loaders, 10, loss_func, optimizer, train_loss_list, test_loss_list)

In [None]:
mt.draw_loss(train_loss_list, test_loss_list)

# Evaluate

In [None]:
def evaluate(model, loader=None):
    """Measure and print the classification accuracy of ``model`` over a whole loader.

    Correct predictions and sample counts are accumulated across every batch,
    so the reported figure covers the full evaluation set rather than only the
    final batch.

    Args:
        model: Network to evaluate; batches are moved to the device its
            parameters live on (CPU when the model has no parameters).
        loader: Iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``loaders['test']``.

    Returns:
        Fraction of correctly classified samples, or 0.0 for an empty loader.
    """
    if loader is None:
        loader = loaders['test']

    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(run_device, dtype=torch.float32)
            labels = labels.to(run_device)
            # Predicted class = index of the highest score in each row.
            pred_y = model(images).argmax(dim=1)
            correct += (pred_y == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total if total else 0.0
    print(f'Test Accuracy of the model on the {total} test images: {accuracy*100:.2f}%')
    return accuracy
evaluate(model)

# Evaluation data sampling

In [None]:
# Draw a 5x5 grid of randomly chosen test images, each titled with the predicted
# label, the ground-truth label and whether the two agree.
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5

model.eval()  # evaluation mode is loop-invariant, so it is set once up front
with torch.no_grad():  # pure inference: no autograd graph is needed
    for i in range(1, cols * rows + 1):
        sample_idx = np.random.randint(len(test_data), size=(1,)).item()
        img, gt = test_data[sample_idx]
        img = img.to(device)
        # The model expects a batch dimension: (1, channels=1, 28, 28).
        predicted = model(img.view(1, 1, 28, 28))
        label = predicted.argmax(dim=1).item()
        figure.add_subplot(rows, cols, i)
        plt.title(f"{label}/(GT:{gt} / {gt==label})")
        plt.axis("off")
        plt.imshow(img.view(28, 28).cpu(), cmap="gray")

plt.tight_layout()
plt.show()

# Test

In [None]:
mt.show_sample_predict_cnn(model, device, test_data)