In [10]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import torchvision.utils as vutils
import matplotlib.pyplot as plt

from vae_auto_encoder import VAEAutoEncoder

In [11]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [12]:
# --- Training hyperparameters -------------------------------------------
num_epochs = 200        # number of passes over the dataset in the training loop
batch_size = 32         # batch size handed to the DataLoader
image_size = 128        # images are resized to image_size x image_size
learning_rate = 5e-4    # initial learning rate for Adam
decay_factor = 0.99     # LR multiplier is decay_factor ** epoch
r_loss_factor = 10000   # weight applied to the reconstruction (MSE) loss
z_dim_size = 200        # last entry of the model's linear_sizes (latent size)

# --- Paths --------------------------------------------------------------
# data url: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
# download "Align&Cropped Images.zip"
data_path = './data/celeba/'
model_save_path = './vae_faces_model.pth'
save_folder = './images/celeba'

# Create the output tree. makedirs(exist_ok=True) also creates missing parent
# folders (e.g. './images') and fills in any subfolder that is absent even
# when save_folder itself already exists, so re-running this cell is safe.
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(save_folder, subfolder), exist_ok=True)

In [13]:
# Preprocessing: resize every image to a fixed square, then convert it to a
# tensor. No Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) step is applied.
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# ImageFolder yields (image, class_index) pairs read from data_path.
dataset = datasets.ImageFolder(root=data_path, transform=preprocess)

# Shuffled mini-batches, loaded by 4 worker processes into pinned memory.
loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
dataloader = DataLoader(dataset=dataset, **loader_options)


In [14]:
# Architecture settings gathered in one place so they are easy to compare.
# With image_size = 128 and four stride-2 convolutions the feature map is
# 128 / 2**4 = 8 pixels wide, so the flattened size is 64 * 8 * 8 = 4096,
# which matches linear_sizes[0] and view_size.
model_kwargs = dict(
    encoder_channels=[3, 32, 64, 64, 64],
    encoder_kernel_sizes=[3, 3, 3, 3],
    encoder_strides=[2, 2, 2, 2],
    decoder_channels=[64, 64, 64, 32, 3],
    decoder_kernel_sizes=[3, 3, 3, 3],
    decoder_strides=[2, 2, 2, 2],
    linear_sizes=[4096, z_dim_size],
    view_size=[-1, 64, 8, 8],
    use_batch_norm=True,
    use_dropout=True,
)

# NOTE(review): the leading positional 4 is presumably the number of
# conv layers -- confirm against vae_auto_encoder.VAEAutoEncoder.
model = VAEAutoEncoder(4, **model_kwargs)
model = model.to(device)
model.train()  # training mode for the dropout / batch-norm layers

print(model)

VAEAutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
    (3): Dropout(p=0.25, inplace=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.01, inplace=True)
    (7): Dropout(p=0.25, inplace=False)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.01, inplace=True)
    (11): Dropout(p=0.25, inplace=False)
    (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr

In [15]:
def _lr_multiplier(epoch):
    """Return the factor applied to the base learning rate at `epoch`."""
    return decay_factor ** epoch


# Adam over all model parameters; the scheduler scales the base learning
# rate by decay_factor ** epoch each time scheduler.step() is called.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_multiplier)

In [16]:
def vae_kl_loss(mu, log_var):
    """Return the batch-mean KL term computed from `mu` and `log_var`.

    For every sample the quantity ``1 + log_var - mu**2 - exp(log_var)`` is
    summed along dimension 1 and scaled by -0.5; the per-sample values are
    then averaged into a single scalar tensor.

    Args:
        mu: tensor with at least two dimensions (dimension 1 is reduced).
        log_var: tensor broadcastable against `mu`.

    Returns:
        A scalar tensor holding the mean of the per-sample KL values.
    """
    elementwise = 1 + log_var - mu.pow(2) - log_var.exp()
    per_sample = -0.5 * elementwise.sum(dim=1)
    return per_sample.mean()


# Reconstruction criterion: mean squared error between output and input.
r_criterion = nn.MSELoss()

Train

In [17]:
# Training loop: one optimisation pass over the dataloader per epoch, followed
# by a learning-rate decay step, a log line and a checkpoint of the weights.
for epoch in range(num_epochs):
    # Sample-weighted sums of the losses, turned into per-sample means below.
    running_loss = 0.0
    running_r_loss = 0.0
    running_kl_loss = 0.0

    # The class labels produced by the dataset are not used by the VAE, so
    # they are discarded instead of being copied to the device every batch.
    for inputs, _ in dataloader:
        # non_blocking only has an effect for pinned host memory -> CUDA
        # copies; it is a no-op when running on the CPU.
        inputs = inputs.to(device, non_blocking=True)
        batch_count = inputs.size(0)

        optimizer.zero_grad()

        # Gradients are enabled by default, so no explicit grad context.
        outputs, mu, log_var = model(inputs)
        r_loss = r_loss_factor * r_criterion(outputs, inputs)
        kl_loss = vae_kl_loss(mu, log_var)
        loss = r_loss + kl_loss
        loss.backward()
        optimizer.step()

        # Both loss terms are batch means; weight them by the batch size so
        # a smaller final batch is averaged correctly.
        running_loss += loss.item() * batch_count
        running_r_loss += r_loss.item() * batch_count
        running_kl_loss += kl_loss.item() * batch_count

    # Decay the learning rate once per epoch, after the optimizer updates.
    scheduler.step()

    num_samples = len(dataset)
    epoch_loss = running_loss / num_samples
    epoch_r_loss = running_r_loss / num_samples
    epoch_kl_loss = running_kl_loss / num_samples
    print('Epoch {0:03d}\tLoss: {1:0.5f}\tr_loss: {2:0.5f}\tkl_loss: {3:0.5f}'.format(
        epoch, epoch_loss, epoch_r_loss, epoch_kl_loss))

    # Overwrite the checkpoint after every epoch so an interrupted run
    # still leaves the most recent weights on disk.
    torch.save(model.state_dict(), model_save_path)

Epoch 000	Loss: 278.00420	r_loss: 222.44046	kl_loss: 55.56374
Epoch 001	Loss: 230.98183	r_loss: 173.46623	kl_loss: 57.51560
Epoch 002	Loss: 224.32764	r_loss: 166.83930	kl_loss: 57.48834
Epoch 003	Loss: 220.66469	r_loss: 163.35540	kl_loss: 57.30930
Epoch 004	Loss: 218.34081	r_loss: 161.07212	kl_loss: 57.26869
Epoch 005	Loss: 216.53039	r_loss: 159.34907	kl_loss: 57.18133
Epoch 006	Loss: 215.33851	r_loss: 158.17660	kl_loss: 57.16191
Epoch 007	Loss: 214.38437	r_loss: 157.22137	kl_loss: 57.16300
Epoch 008	Loss: 213.43553	r_loss: 156.33048	kl_loss: 57.10505
Epoch 009	Loss: 212.73868	r_loss: 155.62917	kl_loss: 57.10951
Epoch 010	Loss: 212.15582	r_loss: 155.00868	kl_loss: 57.14714
Epoch 011	Loss: 211.65693	r_loss: 154.51514	kl_loss: 57.14179
Epoch 012	Loss: 211.19617	r_loss: 154.05045	kl_loss: 57.14571
Epoch 013	Loss: 210.86381	r_loss: 153.70611	kl_loss: 57.15771
Epoch 014	Loss: 210.53479	r_loss: 153.34663	kl_loss: 57.18816
Epoch 015	Loss: 210.23012	r_loss: 153.04756	kl_loss: 57.18256
Epoch 01