## 변이형 오토인코더 훈련 - 얼굴 데이터셋

### 라이브러리 임포트

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import torchvision.utils as vutils
import matplotlib.pyplot as plt

from vae_auto_encoder import VAEAutoEncoder

In [2]:
# Pick the compute device once: use CUDA when PyTorch reports it as
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
# --- Optimisation schedule -------------------------------------------------
num_epochs = 200
batch_size = 32
learning_rate = 0.0005
decay_factor = 0.99        # base of the per-epoch learning-rate multiplier

# --- Model / loss settings -------------------------------------------------
image_size = 128           # images are resized to image_size x image_size
z_dim_size = 200           # last entry of the model's `linear_sizes`
r_loss_factor = 10_000     # multiplier applied to the reconstruction (MSE) loss

# --- Filesystem locations --------------------------------------------------
# CelebA dataset page: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
# Download the "Align&Cropped Images.zip" archive and extract it under
# `data_path`.
data_path = '../data/celeba/'
model_save_path = './vae_faces_model.pth'
image_save_folder = './images/vae_celeba'

# Make sure the image output folder exists (no error if it already does).
os.makedirs(image_save_folder, exist_ok=True)

### 데이터 적재

In [4]:
# Preprocessing applied to every image: resize to a fixed square, then
# convert the PIL image to a float tensor.
image_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# ImageFolder reads images from the class sub-directories of `data_path`.
dataset = datasets.ImageFolder(root=data_path, transform=image_transform)
dataset_size = len(dataset)

# Shuffled mini-batches, loaded by 4 worker processes into pinned host memory.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [5]:
# Report how many images the dataset contains (dataset_size == len(dataset)).
print("Dataset size:", dataset_size)

Dataset size: 202599


### 모델 만들기

In [6]:
# Keyword configuration for the VAE. The four stride-2 encoder convolutions
# reduce a 128x128 input to 8x8 with 64 channels, i.e. 64 * 8 * 8 = 4096
# features, which is the first entry of `linear_sizes` and the shape in
# `view_size`.
vae_config = dict(
    encoder_channels=[3, 32, 64, 64, 64],
    encoder_kernel_sizes=[3, 3, 3, 3],
    encoder_strides=[2, 2, 2, 2],
    decoder_channels=[64, 64, 64, 32, 3],
    decoder_kernel_sizes=[3, 3, 3, 3],
    decoder_strides=[2, 2, 2, 2],
    linear_sizes=[4096, z_dim_size],
    view_size=[-1, 64, 8, 8],
    use_batch_norm=True,
    use_dropout=True,
)

# NOTE(review): the leading positional 4 presumably is the number of
# conv layers per encoder/decoder — confirm against VAEAutoEncoder.
model = VAEAutoEncoder(4, **vae_config)
model = model.to(device)

# Put the model in training mode.
model.train()

print(model)

VAEAutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
    (3): Dropout(p=0.25, inplace=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.01, inplace=True)
    (7): Dropout(p=0.25, inplace=False)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.01, inplace=True)
    (11): Dropout(p=0.25, inplace=False)
    (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr

### 모델 훈련

In [7]:
def _lr_multiplier(epoch):
    """Return the factor applied to the base learning rate at `epoch`."""
    return decay_factor ** epoch


# Adam on all model parameters; the learning rate is scaled by
# decay_factor ** epoch each time scheduler.step() is called.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_multiplier)

In [8]:
def vae_kl_loss(mu, log_var):
    """Return the batch-mean KL divergence term of the VAE loss.

    For each sample, computes
    ``-0.5 * sum(1 + log_var - mu**2 - exp(log_var))`` over dimension 1,
    then averages the per-sample values.

    Args:
        mu: Tensor of latent means; dimension 1 is summed over.
        log_var: Tensor of latent log-variances, same shape as ``mu``.

    Returns:
        A scalar tensor holding the mean of the per-sample KL terms.
    """
    elementwise = 1 + log_var - mu.pow(2) - log_var.exp()
    per_sample = -0.5 * elementwise.sum(dim=1)
    return per_sample.mean()


# Reconstruction loss: element-wise mean squared error.
r_criterion = nn.MSELoss()

In [9]:
# Train the VAE for `num_epochs` passes over the dataset. Each step minimises
# the weighted reconstruction loss plus the KL term; per-epoch averages are
# printed and the weights are saved after every epoch.
for epoch in range(num_epochs):
    running_loss = 0.0
    running_r_loss = 0.0
    running_kl_loss = 0.0
    # The dataloader yields (images, labels) pairs. The labels are not used
    # by this objective, so they are not copied to the device.
    for inputs, _ in dataloader:
        # The DataLoader is created with pin_memory=True, so the host-to-device
        # copy can be issued asynchronously.
        inputs = inputs.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Gradient tracking is on by default, so no explicit
        # torch.set_grad_enabled(True) context is needed here.
        outputs, mu, log_var = model(inputs)
        r_loss = r_loss_factor * r_criterion(outputs, inputs)
        kl_loss = vae_kl_loss(mu, log_var)
        loss = r_loss + kl_loss
        loss.backward()
        optimizer.step()

        # Weight each batch-mean loss by its batch size so that dividing by
        # dataset_size below gives a per-sample average even when the last
        # batch is smaller.
        n_samples = inputs.size(0)
        running_loss += loss.item() * n_samples
        running_r_loss += r_loss.item() * n_samples
        running_kl_loss += kl_loss.item() * n_samples

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()
    epoch_loss = running_loss / dataset_size
    epoch_r_loss = running_r_loss / dataset_size
    epoch_kl_loss = running_kl_loss / dataset_size
    print('Epoch {0:03d}\tLoss: {1:0.5f}\tr_loss: {2:0.5f}\tkl_loss: {3:0.5f}'.format(
        epoch + 1, epoch_loss, epoch_r_loss, epoch_kl_loss))

    # Overwrite the checkpoint with the current weights after every epoch.
    torch.save(model.state_dict(), model_save_path)

Epoch 001	Loss: 281.82482	r_loss: 225.69137	kl_loss: 56.13345
Epoch 002	Loss: 232.89336	r_loss: 175.44581	kl_loss: 57.44755
Epoch 003	Loss: 225.86614	r_loss: 168.16207	kl_loss: 57.70407
Epoch 004	Loss: 222.05632	r_loss: 164.50038	kl_loss: 57.55594
Epoch 005	Loss: 219.57540	r_loss: 162.14158	kl_loss: 57.43382
Epoch 006	Loss: 217.92410	r_loss: 160.52807	kl_loss: 57.39603
Epoch 007	Loss: 216.44781	r_loss: 159.07634	kl_loss: 57.37147
Epoch 008	Loss: 215.30489	r_loss: 157.95076	kl_loss: 57.35413
Epoch 009	Loss: 214.24610	r_loss: 156.90942	kl_loss: 57.33668
Epoch 010	Loss: 213.40035	r_loss: 156.07523	kl_loss: 57.32512
Epoch 011	Loss: 212.76377	r_loss: 155.46015	kl_loss: 57.30362
Epoch 012	Loss: 212.08660	r_loss: 154.78868	kl_loss: 57.29793
Epoch 013	Loss: 211.55340	r_loss: 154.25583	kl_loss: 57.29757
Epoch 014	Loss: 211.10217	r_loss: 153.79757	kl_loss: 57.30460
Epoch 015	Loss: 210.63634	r_loss: 153.33856	kl_loss: 57.29778
Epoch 016	Loss: 210.37029	r_loss: 153.07739	kl_loss: 57.29290
Epoch 01