## 변이형 오토인코더 훈련

### 라이브러리 임포트

In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import torch.nn.functional as F

from vae_auto_encoder import VAEAutoEncoder

In [2]:
# Select the compute device: use CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
# Training configuration shared by the cells below.
num_epochs = 200  # number of passes over the training set in the training loop
batch_size = 64  # mini-batch size handed to the DataLoader
learning_rate = 1e-3  # initial learning rate passed to the Adam optimizer
r_loss_factor = 1000  # multiplier applied to the reconstruction loss before the KL term is added
decay_factor = 0.99  # base of the per-epoch LR multiplier (decay_factor ** epoch) used by LambdaLR
data_path = '../data/'  # root directory given to datasets.MNIST (download target)
model_save_path = "./vae_digits_model.pth"  # file the model state_dict is saved to after each epoch

### 데이터 적재

In [4]:
# MNIST training split rooted at data_path (downloaded if absent); every sample
# is passed through transforms.ToTensor().
train_data = datasets.MNIST(data_path, train=True, transform=transforms.ToTensor(), download=True)
dataset_size = len(train_data)

# Shuffled mini-batch iterator over the training split, using 4 worker
# processes and pinned host memory.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [5]:
# Report how many training samples were loaded (dataset_size is len(train_data)).
print("Dataset size:", dataset_size)

Dataset size: 60000


### 모델 만들기

In [6]:
# Constructor arguments for the VAE, gathered in one mapping so the
# architecture can be read (and tweaked) at a glance.
vae_config = dict(
    num_layers=4,
    encoder_channels=[1, 32, 64, 64, 64],
    encoder_kernel_sizes=[3, 3, 3, 3],
    encoder_strides=[1, 2, 2, 1],
    decoder_channels=[64, 64, 64, 32, 1],
    decoder_kernel_sizes=[3, 3, 3, 3],
    decoder_strides=[1, 2, 2, 1],
    linear_sizes=[3136, 2],
    view_size=[-1, 64, 7, 7],
    use_batch_norm=False,
    use_dropout=False,
)
model = VAEAutoEncoder(**vae_config).to(device)
model.train()  # put the module in training mode


def lr_decay(epoch):
    """Return the LR multiplier for ``epoch``: ``decay_factor ** epoch``."""
    return decay_factor ** epoch


# Adam optimizer whose learning rate is scaled by lr_decay(epoch) via LambdaLR.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_decay)

print(model)

VAEAutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.01, inplace=True)
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (3): LeakyReLU(negative_slope=0.01, inplace=True)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (5): LeakyReLU(negative_slope=0.01, inplace=True)
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): LeakyReLU(negative_slope=0.01, inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
  )
  (mu_layer): Linear(in_features=3136, out_features=2, bias=True)
  (log_var_layer): Linear(in_features=3136, out_features=2, bias=True)
  (linear): Linear(in_features=2, out_features=3136, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_

### 모델 훈련

In [7]:
def vae_kl_loss(mu, log_var):
    """Return the batch-averaged KL term of the VAE loss.

    For every sample this evaluates
    ``-0.5 * sum(1 + log_var - mu**2 - exp(log_var))`` over dimension 1,
    then averages the per-sample values.

    Args:
        mu: Tensor of latent means; reduced over dim 1
            (presumably shaped (batch, latent_dim) -- confirm with the model).
        log_var: Tensor of latent log-variances, combined elementwise with ``mu``.

    Returns:
        A scalar tensor holding the mean of the per-sample KL values.
    """
    elementwise = 1 + log_var - mu.pow(2) - log_var.exp()
    per_sample = -0.5 * elementwise.sum(dim=1)
    return per_sample.mean()

# Reconstruction criterion: MSE between the model output and its input;
# the training loop multiplies it by r_loss_factor.
r_criterion = nn.MSELoss()

In [8]:
# Train for num_epochs epochs on the objective r_loss_factor * MSE + KL.
# After each epoch: step the LR scheduler, print the dataset-averaged losses,
# and save the model weights to model_save_path.
for epoch in range(num_epochs):
    running_loss = 0.0
    running_r_loss = 0.0
    running_kl_loss = 0.0
    # The loss below never reads the labels, so they are discarded here rather
    # than being copied to the device on every batch.
    for inputs, _ in train_loader:
        inputs = inputs.to(device)
        batch_len = inputs.size(0)

        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            outputs, mu, log_var = model(inputs)
            r_loss = r_loss_factor * r_criterion(outputs, inputs)
            kl_loss = vae_kl_loss(mu, log_var)
            loss = r_loss + kl_loss
            loss.backward()
            optimizer.step()

        # Accumulate batch-mean losses weighted by batch size, so dividing by
        # dataset_size below gives a per-sample average even if the final
        # batch is smaller than batch_size.
        running_loss += loss.item() * batch_len
        running_r_loss += r_loss.item() * batch_len
        running_kl_loss += kl_loss.item() * batch_len

    scheduler.step()  # advance the LR schedule once per epoch
    epoch_loss = running_loss / dataset_size
    epoch_r_loss = running_r_loss / dataset_size
    epoch_kl_loss = running_kl_loss / dataset_size
    print('Epoch {0:03d}\tLoss: {1:0.5f}\tr_loss: {2:0.5f}\tkl_loss: {3:0.5f}'.format(
        epoch + 1, epoch_loss, epoch_r_loss, epoch_kl_loss))

    # Overwrite the saved weights at model_save_path with this epoch's state.
    torch.save(model.state_dict(), model_save_path)
    

Epoch 001	Loss: 57.06260	r_loss: 53.84807	kl_loss: 3.21454
Epoch 002	Loss: 49.17750	r_loss: 44.72828	kl_loss: 4.44922
Epoch 003	Loss: 47.68091	r_loss: 42.94135	kl_loss: 4.73956
Epoch 004	Loss: 46.83676	r_loss: 41.95326	kl_loss: 4.88349
Epoch 005	Loss: 46.32283	r_loss: 41.34881	kl_loss: 4.97402
Epoch 006	Loss: 45.92388	r_loss: 40.84339	kl_loss: 5.08049
Epoch 007	Loss: 45.59897	r_loss: 40.47061	kl_loss: 5.12836
Epoch 008	Loss: 45.30886	r_loss: 40.12167	kl_loss: 5.18719
Epoch 009	Loss: 45.05550	r_loss: 39.82849	kl_loss: 5.22701
Epoch 010	Loss: 44.87323	r_loss: 39.58311	kl_loss: 5.29011
Epoch 011	Loss: 44.68638	r_loss: 39.37386	kl_loss: 5.31252
Epoch 012	Loss: 44.49468	r_loss: 39.14780	kl_loss: 5.34689
Epoch 013	Loss: 44.38228	r_loss: 39.00866	kl_loss: 5.37362
Epoch 014	Loss: 44.19775	r_loss: 38.79111	kl_loss: 5.40664
Epoch 015	Loss: 44.08610	r_loss: 38.65552	kl_loss: 5.43058
Epoch 016	Loss: 43.98302	r_loss: 38.54708	kl_loss: 5.43594
Epoch 017	Loss: 43.81777	r_loss: 38.34772	kl_loss: 5.470