In [29]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms

In [30]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
training_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
validation_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

training_loader = torch.utils.data.DataLoader(dataset=training_dataset, batch_size=100, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=100, shuffle=False)


In [37]:
class Net(nn.Module):

  def __init__(self):
    super().__init__()
    # 32px 이었다가, conv를 거치면서 절반으로 감소
    self.conv1 = nn.Conv2d(1, 16, 3, 1, padding=1)
    self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1)
    self.conv3 = nn.Conv2d(32, 64, 3, 1, padding=1)
    self.fc1 = nn.Linear(4*4*64, 500)
    self.dropout1 = nn.Dropout(0.5)
    self.fc2 = nn.Linear(500, 10)

  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.max_pool2d(x, 2, 2)
    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, 2, 2)
    x = F.relu(self.conv3(x))
    x = F.max_pool2d(x, 2, 2)
    x = x.view(-1, 4*4*64)
    x = F.relu(self.fc1(x))
    x = self.dropout1(x)
    x = self.fc2(x)
    return x

In [None]:
model = Net().to(device)
torch.save(model.state_dict(), 'MNIST_BaseModel_weights.pth')

In [38]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [39]:
epochs = 12
running_loss_history = []
running_correct_history = []
validation_running_loss_history = []
validation_running_correct_history = []

for e in range(epochs):

  running_loss = 0.0
  running_correct = 0.0
  validation_running_loss = 0.0
  validation_running_correct = 0.0

  for inputs, labels in training_loader:

    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = model(inputs)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    _, preds = torch.max(outputs, 1)

    running_correct += torch.sum(preds == labels.data)
    running_loss += loss.item()



  else:
    # 훈련팔 필요가 없으므로 메모리 절약
    with torch.no_grad():

      for val_input, val_label in validation_loader:

        val_input = val_input.to(device)
        val_label = val_label.to(device)
        val_outputs = model(val_input)
        val_loss = criterion(val_outputs, val_label)

        _, val_preds = torch.max(val_outputs, 1)
        validation_running_loss += val_loss.item()
        validation_running_correct += torch.sum(val_preds == val_label.data)


    epoch_loss = running_loss / len(training_loader)
    epoch_acc = running_correct.float() / len(training_loader)
    running_loss_history.append(epoch_loss)
    running_correct_history.append(epoch_acc)

    val_epoch_loss = validation_running_loss / len(validation_loader)
    val_epoch_acc = validation_running_correct.float() / len(validation_loader)
    validation_running_loss_history.append(val_epoch_loss)
    validation_running_correct_history.append(val_epoch_acc)

    print("===================================================")
    print("epoch: ", e + 1)
    print("training loss: {:.5f}, acc: {:5f}".format(epoch_loss, epoch_acc))
    print("validation loss: {:.5f}, acc: {:5f}".format(val_epoch_loss, val_epoch_acc))

epoch:  1
training loss: 1.72446, acc: 37.391998
validation loss: 1.55120, acc: 43.720001
epoch:  2
training loss: 1.44779, acc: 47.716000
validation loss: 1.39370, acc: 49.570000
epoch:  3
training loss: 1.33785, acc: 51.681999
validation loss: 1.31821, acc: 52.459999
epoch:  4
training loss: 1.25203, acc: 55.012001
validation loss: 1.27294, acc: 54.060001
epoch:  5
training loss: 1.18392, acc: 57.695999
validation loss: 1.21749, acc: 56.830002
epoch:  6
training loss: 1.12857, acc: 59.782001
validation loss: 1.17000, acc: 58.200001
epoch:  7
training loss: 1.07812, acc: 61.966000
validation loss: 1.13756, acc: 59.720001
epoch:  8
training loss: 1.03751, acc: 63.125999
validation loss: 1.11344, acc: 60.570000
epoch:  9
training loss: 0.99810, acc: 64.706001
validation loss: 1.11099, acc: 61.459999
epoch:  10
training loss: 0.97338, acc: 65.606003
validation loss: 1.09933, acc: 61.610001
epoch:  11
training loss: 0.93950, acc: 66.776001
validation loss: 1.08968, acc: 61.980000
epoch:  