# In [None]:
## 인공지능을 위한 수학 2 
## 제 6강 : MNIST Classification using PyTorch


## Module Import 


from __future__ import print_function
from torch import nn, optim, cuda
from torch.utils import data
from torchvision import datasets, transforms
import torch.nn.functional as F
import time
import torch


## Google Drive Mount 

# Colab-only step: mount Google Drive at /content/my_drive so that the
# dataset folder used below (under /content/my_drive/MyDrive/...) is reachable.
from google.colab import drive
drive.mount('/content/my_drive')


## Training Setup

# Mini-batch size shared by the training and test loaders.
batch_size = 64

# Run on the GPU when CUDA reports one, otherwise fall back to the CPU.
if cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

print(f'Training MNIST Model on {device}\n{"="*44}')


## MNIST dataset Download 
# Both splits are stored in the same folder on the mounted drive and are
# converted to tensors on access. Only the training split asks for a download.
train_dataset = datasets.MNIST(
    '/content/my_drive/MyDrive/mai1/mnist_data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# No download is requested here, so the files are expected to exist already.
test_dataset = datasets.MNIST(
    '/content/my_drive/MyDrive/mai1/mnist_data/',
    train=False,
    transform=transforms.ToTensor(),
)

## Data loader
# Training batches are shuffled; evaluation batches keep the dataset order.
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


## Network Architecture Definition

class Net(nn.Module):
  """Small convolutional classifier for 1x28x28 MNIST digit images.

  Layout: conv(1->16) + ReLU + 2x2 max-pool, conv(16->16) + ReLU,
  conv(16->32) + ReLU + 2x2 max-pool, then two fully connected layers
  (1568 -> 120 -> 10). ``forward`` returns raw class scores (no softmax).
  """

  def __init__(self):
    super(Net, self).__init__()     # plain super().__init__() would work as well
    # Every convolution is 3x3 with padding=1, so it keeps the spatial size.
    self.l1 = nn.Conv2d(1, 16, 3, padding=1)
    # l2/l3 have to be Conv2d: nn.Linear takes no kernel size or padding and
    # raised a TypeError at construction time.
    self.l2 = nn.Conv2d(16, 16, 3, padding=1)
    self.l3 = nn.Conv2d(16, 32, 3, padding=1)
    # Two 2x2 poolings shrink 28x28 to 7x7; 32 channels * 7 * 7 = 1568.
    self.l4 = nn.Linear(1568, 120)
    self.l5 = nn.Linear(120, 10)

  def forward(self, x):
    """Map a batch of (N, 1, 28, 28) images to (N, 10) class scores."""
    x = F.max_pool2d(F.relu(self.l1(x)), 2)   # -> (N, 16, 14, 14)
    x = F.relu(self.l2(x))                    # -> (N, 16, 14, 14)
    x = F.max_pool2d(F.relu(self.l3(x)), 2)   # -> (N, 32, 7, 7)
    x = x.view(-1, 1568)                      # flatten for the linear layers
    x = F.relu(self.l4(x))
    return self.l5(x)

## Network Load

# Build the network and place its parameters on the selected device.
model = Net().to(device)

## Model Load

# model.load_state_dict(torch.load("004_my_model.pt"))   # e.g. if training stopped after the 4th epoch, use this to resume from those weights


## Training Setup

# Cross-entropy loss on the raw class scores, optimized with SGD + momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.5,
)

## Model Saving    The weights and biases can be inspected this way; they are stored in model.state_dict().

#for param_tensor in model.state_dict():
#  print(param_tensor, "\t", model.state_dict()[param_tensor].size()) # printing this produces a very large amount of weight output

## Train Function 

def train(epoch):
  """Run one training epoch over ``train_loader`` and save a checkpoint.

  Uses the module-level ``model``, ``optimizer``, ``criterion`` and
  ``device``. Progress is printed every 10 batches, and the model weights are
  written once, after the last batch, to "%03d_my_model.pt" % epoch.

  Args:
    epoch: Epoch number, used in the progress log and the checkpoint name.
  """
  model.train()    # switch the model to training mode
  # The batch is named `inputs` so it does not shadow the imported
  # torch.utils `data` module.
  for batch_idx, (inputs, target) in enumerate(train_loader):
    inputs, target = inputs.to(device), target.to(device)
    optimizer.zero_grad()     # gradients accumulate, so clear them before each step
    output = model(inputs)    # forward pass
    loss = criterion(output, target)
    loss.backward()     # compute the gradients
    optimizer.step()     # update the weights
    if batch_idx % 10 == 0:
      print('Train Epoch : {} | Batch Status : {}/{} ({:.0f}%) | Loss : {:.6f}'.format(
          epoch, batch_idx*len(inputs), len(train_loader.dataset),
          100. * batch_idx / len(train_loader), loss.item()))

  ## Save weight parameter
  # Saved once per epoch, after the loop. Saving inside the loop rewrote the
  # same file after every mini-batch.
  torch.save(model.state_dict(), "%03d_my_model.pt"%epoch)

  # For a more detailed checkpoint, save a dict instead, e.g.:
  #   torch.save({'epoch' : epoch,
  #               'model_state_dict': model.state_dict(),
  #               'optimizer_state_dict': optimizer.state_dict(),
  #               'loss': loss
  #               }, "%03d_my_model.pt"%epoch)

## Test Function

def test():
  """Evaluate the module-level ``model`` on ``test_loader`` and print results.

  Prints the average per-sample loss and the number/percentage of correctly
  classified samples. Returns nothing.
  """
  model.eval()    # switch the model to evaluation mode
  test_loss = 0.0
  correct = 0
  num_samples = len(test_loader.dataset)

  # No gradients are needed for evaluation; skipping them saves memory and time.
  with torch.no_grad():
    # Targets exist for MNIST, so they are used here to score the predictions.
    for inputs, target in test_loader:
      inputs, target = inputs.to(device), target.to(device)
      output = model(inputs)

      # Assumes `criterion` uses the default 'mean' reduction: scale the batch
      # mean back to a batch sum so the final division by the sample count
      # yields a true per-sample average (the last batch may be smaller).
      test_loss += criterion(output, target).item() * inputs.size(0)

      # Index of the highest score is the predicted class.
      pred = output.argmax(dim=1)
      # .item() keeps `correct` a plain int, so it prints as a number
      # rather than as "tensor(...)".
      correct += (pred == target).sum().item()

  test_loss /= num_samples
  print(f'==================\nTest set: Average loss : {test_loss:.4f}, Accuracy : {correct}/{num_samples}'
        f'({100. * correct / num_samples:.0f}%)')
  

## Main

if __name__ == '__main__':

  # Train for epochs 1..9, evaluating after each one and reporting timings.
  since = time.time()
  for epoch in range(1, 10):
    epoch_start = time.time()
    train(epoch)
    m, s = divmod(time.time() - epoch_start, 60)
    print(f'Training time: {m:.0f}m {s:.0f}s')

    # Time the evaluation on its own so the figure excludes training time.
    test_start = time.time()
    test()
    m, s = divmod(time.time() - test_start, 60)
    print(f'Testing time: {m:.0f}m {s:.0f}s')

  # Total is measured from `since`, the start of the whole run, not from the
  # start of the last epoch.
  m, s = divmod(time.time() - since, 60)
  print(f'Total time : {m:.0f}m {s: .0f}s \nModel was trained on {device}!')