In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [3]:
# Each augmentation below is studied in isolation: it gets its own transform
# pipeline, its own copy of the MNIST training set and its own DataLoader.
augmenting_method = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.GaussianBlur(kernel_size=(7, 13), sigma=(0.1, 0.2)),
]
N = len(augmenting_method)

# One pipeline per augmentation: apply the augmentation, then convert to a tensor.
transformer_train = [
    transforms.Compose([augmentation, transforms.ToTensor()])
    for augmentation in augmenting_method
]

# Plain tensor conversion without any augmentation.
transformer = transforms.Compose([transforms.ToTensor(),])

# training_loader[k] serves MNIST training batches augmented by augmenting_method[k].
training_loader = []
for train_transform in transformer_train:
    training_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
    training_loader.append(
        torch.utils.data.DataLoader(dataset=training_dataset, batch_size=100, shuffle=True)
    )

# The validation split is never augmented and is read in a fixed order.
validation_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=100, shuffle=False)

In [4]:
class Net(nn.Module):
    """LeNet-5-style CNN mapping single-channel images to 10 class logits.

    The flatten step assumes 16 feature maps of 5x5, which is what the
    conv/pool stack yields for 28x28 inputs (28 -> 28 -> 14 -> 10 -> 5).
    Layer attribute names follow the classic C1/P2/C3/P4/C5/F6/F7 labels and
    determine the keys of ``state_dict()``.
    """

    def __init__(self):
        super().__init__()
        # C1: 5x5 conv, 1 -> 6 channels; padding=2 keeps the spatial size.
        self.C1_layer = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2),
            nn.ReLU(),
        )
        # P2: 2x2 max-pool, halves height and width.
        self.P2_layer = nn.MaxPool2d(kernel_size=2, stride=2)
        # C3: 5x5 conv, 6 -> 16 channels, no padding.
        self.C3_layer = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
        )
        # P4: second 2x2 max-pool.
        self.P4_layer = nn.MaxPool2d(kernel_size=2, stride=2)
        # C5: fully connected over the flattened 16x5x5 feature maps.
        self.C5_layer = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
        )
        # F6: hidden fully connected layer.
        self.F6_layer = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU(),
        )
        # F7: final projection to 10 raw class scores (no softmax applied here).
        self.F7_layer = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw class scores of shape (rows, 10) for the image batch ``x``."""
        features = self.P2_layer(self.C1_layer(x))
        features = self.P4_layer(self.C3_layer(features))
        # Collapse channel and spatial dimensions into one vector per sample.
        flat = features.reshape(-1, 16 * 5 * 5)
        hidden = self.F6_layer(self.C5_layer(flat))
        return self.F7_layer(hidden)

In [5]:
# One independently initialised network per augmentation method, placed on `device`.
model = [Net().to(device) for _ in range(N)]

In [6]:
# The same cross-entropy loss object is shared by every run.
criterion = nn.CrossEntropyLoss()
# optimizer[k] is an Adam instance that updates only model[k]'s parameters.
optimizer = [torch.optim.Adam(model[idx].parameters(), lr=0.001) for idx in range(N)]

In [7]:
def train(device, model, criterion, optimizer, training_loader, validation_loader, epochs=12):
    """Train ``model`` and evaluate it on the validation data after every epoch.

    Args:
        device: torch device each batch is moved to before the forward pass.
        model: ``nn.Module`` to optimise; called as ``model(inputs)``.
        criterion: loss callable taking ``(outputs, labels)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        training_loader: iterable of ``(inputs, labels)`` training batches.
        validation_loader: iterable of ``(inputs, labels)`` validation batches.
        epochs: number of passes over ``training_loader`` (default 12).

    Returns:
        dict with keys ``"train_loss"``, ``"train_acc"``, ``"val_loss"`` and
        ``"val_acc"``; each value is a list holding one float per epoch.
        Losses are the mean of the per-batch criterion values; accuracies are
        percentages in the range 0-100.
    """
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    # Remember the caller's mode so it can be restored once training is done.
    was_training = model.training

    for e in range(epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        running_correct = 0
        train_samples = 0
        train_batches = 0

        for inputs, labels in training_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            preds = outputs.argmax(dim=1)
            running_correct += (preds == labels).sum().item()
            running_loss += loss.item()
            train_samples += labels.size(0)
            train_batches += 1

        # ---- validation pass ----
        model.eval()
        validation_running_loss = 0.0
        validation_running_correct = 0
        val_samples = 0
        val_batches = 0

        # No parameter update happens here, so skip gradient tracking to save memory.
        with torch.no_grad():
            for val_input, val_label in validation_loader:
                val_input = val_input.to(device)
                val_label = val_label.to(device)
                val_outputs = model(val_input)
                val_loss = criterion(val_outputs, val_label)

                val_preds = val_outputs.argmax(dim=1)
                validation_running_loss += val_loss.item()
                validation_running_correct += (val_preds == val_label).sum().item()
                val_samples += val_label.size(0)
                val_batches += 1

        # Accuracy is correct / samples seen, expressed in percent. Dividing by
        # the number of batches only looks like a percentage when the batch
        # size is exactly 100. max(..., 1) keeps an empty loader from raising.
        epoch_loss = running_loss / max(train_batches, 1)
        epoch_acc = 100.0 * running_correct / max(train_samples, 1)
        val_epoch_loss = validation_running_loss / max(val_batches, 1)
        val_epoch_acc = 100.0 * validation_running_correct / max(val_samples, 1)

        history["train_loss"].append(epoch_loss)
        history["train_acc"].append(epoch_acc)
        history["val_loss"].append(val_epoch_loss)
        history["val_acc"].append(val_epoch_acc)

        print("===================================================")
        print("epoch: ", e + 1)
        print("training loss: {:.5f}, acc: {:5f}".format(epoch_loss, epoch_acc))
        print("validation loss: {:.5f}, acc: {:5f}".format(val_epoch_loss, val_epoch_acc))

    model.train(was_training)
    return history

In [8]:
# Train one network per augmentation method, then write its weights to a file
# named after the augmentation's transform class.
for i in range(N):
    net = model[i]
    train(device, net, criterion, optimizer[i], training_loader[i], validation_loader)
    weights_path = f'mnist_weights_{augmenting_method[i].__class__.__name__}.pth'
    torch.save(net.state_dict(), weights_path)

epoch:  1
training loss: 0.48205, acc: 84.798332
validation loss: 0.18599, acc: 93.959999
epoch:  2
training loss: 0.17125, acc: 94.645004
validation loss: 0.14407, acc: 95.309998
epoch:  3
training loss: 0.12290, acc: 96.221672
validation loss: 0.09540, acc: 96.879997
epoch:  4
training loss: 0.09701, acc: 96.940002
validation loss: 0.09601, acc: 96.939995
epoch:  5
training loss: 0.08128, acc: 97.453339
validation loss: 0.06892, acc: 97.689995
epoch:  6
training loss: 0.07068, acc: 97.763336
validation loss: 0.06356, acc: 97.839996
epoch:  7
training loss: 0.06508, acc: 97.933334
validation loss: 0.05546, acc: 98.139999
epoch:  8
training loss: 0.05612, acc: 98.211670
validation loss: 0.06526, acc: 97.930000
epoch:  9
training loss: 0.05401, acc: 98.264999
validation loss: 0.06082, acc: 97.970001
epoch:  10
training loss: 0.04787, acc: 98.423332
validation loss: 0.06566, acc: 97.779999
epoch:  11
training loss: 0.04368, acc: 98.625000
validation loss: 0.05315, acc: 98.290001
epoch:  