# Get CelebA face images

In [None]:
# Install the Kaggle command-line client quietly (-q); a later cell uses it
# to download the CelebA dataset.
!pip install -q kaggle

In [None]:
from google.colab import files

# Opens Colab's file-upload prompt. Upload your `kaggle.json` API token here:
# the next cell copies `kaggle.json` from the working directory into ~/.kaggle/.
files.upload()

In [None]:
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d jessicali9530/celeba-dataset -p /content/celeba/ --unzip

# Import

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
import torchvision.utils as vutils

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Seed PyTorch's global random number generator with a fixed value.
manual_seed = 999
torch.manual_seed(manual_seed)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# Model

In [None]:
# Means and standard deviations follow the original DCGAN paper.
def weights_init(m):
    """Re-initialise a single layer in place; meant for ``nn.Module.apply``.

    Dispatch is on the layer's class *name*:

    * names containing ``'Conv'``: weights drawn from N(0, 0.02)
      (biases, if any, are left untouched);
    * names containing ``'BatchNorm'``: weights drawn from N(1, 0.02) and
      biases set to 0;
    * anything else (e.g. ``Linear``) is left as it is.

    Args:
        m: the module to initialise.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

## 1) Encoder

In [None]:
# Layer layout follows the original VAE-GAN and DCGAN papers.
class Encoder(nn.Module):
  """Convolutional encoder producing the two heads of a VAE posterior.

  Three stride-2 convolution blocks halve the spatial size each time; the
  result is flattened, passed through a 2048-unit fully connected block and
  projected by two separate linear heads.

  The fully connected block expects 256 * 8 * 8 input features, which is what
  the three stride-2 blocks produce for 64x64 images.

  Args:
    z_dim: size of each output head (latent dimensionality).
  """

  def __init__(self, z_dim=128):
    super().__init__()

    def down_block(c_in, c_out):
      # 5x5 stride-2 conv (no bias, BatchNorm follows) -> BatchNorm -> LeakyReLU(0.2)
      return nn.Sequential(
          nn.Conv2d(c_in, c_out, kernel_size=5, stride=2, padding=2, bias=False),
          nn.BatchNorm2d(c_out),
          nn.LeakyReLU(0.2, inplace=True),
      )

    self.conv1 = down_block(3, 64)
    self.conv2 = down_block(64, 128)
    self.conv3 = down_block(128, 256)

    flat_features = 256 * 8 * 8
    self.fc = nn.Sequential(
        nn.Linear(flat_features, 2048, bias=False),
        nn.BatchNorm1d(2048),
        nn.LeakyReLU(0.2, inplace=True),
    )

    # Two independent heads on top of the shared 2048-d representation.
    self.mu = nn.Linear(2048, z_dim)
    self.log_var = nn.Linear(2048, z_dim)

  def forward(self, x):
    """Return ``(mu, log_var)``, each of shape ``(len(x), z_dim)``."""
    features = self.conv3(self.conv2(self.conv1(x)))
    # Flatten everything except the batch axis before the dense layers.
    hidden = self.fc(features.view(features.size(0), -1))
    return self.mu(hidden), self.log_var(hidden)

## 2) Decoder

In [None]:
# Layer layout follows the original VAE-GAN and DCGAN papers.
class Decoder(nn.Module):
  """Transposed-convolution decoder mapping latent vectors to 3-channel images.

  A fully connected block expands the latent vector to 256 * 8 * 8 features,
  which are reshaped to ``(256, 8, 8)``. Three stride-2 transposed-conv blocks
  double the spatial size each time (8 -> 16 -> 32 -> 64), and a final
  stride-1 transposed conv followed by ``Tanh`` produces the image, so the
  output values lie in [-1, 1].

  Args:
    z_dim: dimensionality of the latent input.
  """

  def __init__(self, z_dim=128):
    super().__init__()

    def up_block(c_in, c_out):
      # 5x5 stride-2 transposed conv (no bias, BatchNorm follows) -> BatchNorm -> ReLU
      return nn.Sequential(
          nn.ConvTranspose2d(c_in, c_out, kernel_size=5, stride=2, padding=2,
                             output_padding=1, bias=False),
          nn.BatchNorm2d(c_out),
          nn.ReLU(True),
      )

    seed_features = 256 * 8 * 8
    self.fc = nn.Sequential(
        nn.Linear(z_dim, seed_features, bias=False),
        nn.BatchNorm1d(seed_features),
        nn.ReLU(True),
    )
    self.convt1 = up_block(256, 256)
    self.convt2 = up_block(256, 128)
    self.convt3 = up_block(128, 32)
    # Output layer: keeps the spatial size, squashes to [-1, 1] with Tanh.
    self.convt4 = nn.Sequential(
        nn.ConvTranspose2d(32, 3, kernel_size=5, stride=1, padding=2, bias=False),
        nn.Tanh(),
    )

  def forward(self, z):
    """Decode latent vectors ``z`` of shape ``(batch, z_dim)`` into images."""
    seed = self.fc(z).view(-1, 256, 8, 8)
    upsampled = self.convt3(self.convt2(self.convt1(seed)))
    return self.convt4(upsampled)

## 3) Discriminator

In [None]:
# Layer layout follows the original VAE-GAN and DCGAN papers.
class Discriminator(nn.Module):
  """Convolutional real/fake classifier scoring three image batches at once.

  One stride-1 conv layer is followed by three stride-2 conv blocks, a
  512-unit fully connected block and a single sigmoid output unit. The first
  fully connected layer expects 256 * 8 * 8 features, i.e. 64x64 input images.
  """

  def __init__(self):
    super().__init__()

    def down_block(c_in, c_out):
      # 5x5 stride-2 conv -> BatchNorm -> LeakyReLU(0.2)
      return nn.Sequential(
          nn.Conv2d(c_in, c_out, kernel_size=5, stride=2, padding=2),
          nn.BatchNorm2d(c_out),
          nn.LeakyReLU(0.2, inplace=True),
      )

    # First layer keeps the resolution and has no normalisation.
    self.conv1 = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
        nn.LeakyReLU(0.2, inplace=True),
    )
    self.conv2 = down_block(32, 128)
    self.conv3 = down_block(128, 256)
    self.conv4 = down_block(256, 256)
    self.fc1 = nn.Sequential(
        nn.Linear(256 * 8 * 8, 512, bias=False),
        nn.BatchNorm1d(512),
        nn.LeakyReLU(0.2, inplace=True),
    )
    self.fc2 = nn.Sequential(
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )

  def forward(self, x_original, x_recon, x_sampled):
    """Score the three batches in a single pass.

    The inputs are concatenated along the batch axis in the order
    ``x_original``, ``x_recon``, ``x_sampled``; the returned tensor has one
    sigmoid score per row, in that same order, with shape ``(total_rows, 1)``.
    """
    stacked = torch.cat((x_original, x_recon, x_sampled))
    feats = stacked
    for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
      feats = stage(feats)
    # Flatten everything except the batch axis before the dense layers.
    hidden = self.fc1(feats.view(feats.size(0), -1))
    return self.fc2(hidden)

## 4) VAEGAN

In [None]:
class VAEGAN(nn.Module):
  """VAE-GAN wrapper bundling an Encoder, a Decoder and a Discriminator.

  All three sub-networks are created here and then re-initialised with
  ``weights_init``.

  Args:
    z_dim: dimensionality of the latent space shared by encoder and decoder.
  """

  def __init__(self, z_dim=128):
    super(VAEGAN, self).__init__()
    self.z_dim = z_dim
    self.encoder = Encoder(z_dim)
    self.decoder = Decoder(z_dim)
    self.discriminator = Discriminator()

    # apply() visits every sub-module, so weights_init sees each layer.
    self.encoder.apply(weights_init)
    self.decoder.apply(weights_init)
    self.discriminator.apply(weights_init)

  def forward(self, x):
    """Encode, reconstruct, sample from the prior and score everything.

    Args:
      x: batch of input images.

    Returns:
      A tuple ``(mu, log_var, dis)`` where ``mu`` and ``log_var`` are the
      encoder outputs and ``dis`` is the discriminator output for the
      originals, the reconstructions and the prior samples, stacked along the
      batch axis in that order (three times as many rows as ``x``).
    """
    # Detached copy: the "real" images fed to the discriminator carry no graph.
    x_original = x.clone().detach()

    mu, log_var = self.encoder(x)
    std = torch.exp(log_var * 0.5)
    # Reparameterisation trick: z = mu + std * eps with eps ~ N(0, I).
    # randn_like draws the noise directly with mu's shape, dtype and device,
    # so this module does not rely on a module-level `device` variable and
    # avoids a host-to-device copy on every step.
    epsilon = torch.randn_like(std)
    z = epsilon * std + mu
    x_recon = self.decoder(z)

    # Images decoded from latent codes drawn from the standard normal prior.
    z_sampled = torch.randn_like(mu)
    x_sampled = self.decoder(z_sampled)

    return mu, log_var, self.discriminator(x_original, x_recon, x_sampled)

# Train

## Hyperparameters

In [None]:
# --- optimisation ---
num_epochs = 5   # full passes over the dataset
batch_size = 64  # images per training step
lr = 3e-4        # learning rate shared by all three optimisers

# --- model ---
z_dim = 128      # latent dimensionality
# gamma balances the style error (GAN loss) against the content error
# (reconstruction loss) in the decoder objective
gamma = 1

## Dataloader

In [None]:
# Images are resized so the shorter side is 64 px, centre-cropped to 64x64,
# converted to tensors and normalised per channel with mean 0.5 / std 0.5.
dataset = torchvision.datasets.ImageFolder(
    root="/content/celeba/img_align_celeba",
    transform=T.Compose(
        [
            T.Resize(64),
            T.CenterCrop(64),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    ),
)

# Shuffled mini-batches, loaded by two worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

## Main

In [None]:
# NOTE(review): real_batch is not used anywhere in the visible notebook; it is
# kept in case other cells rely on it.
real_batch = next(iter(dataloader))
vaegan = VAEGAN(z_dim).to(device)
criterion_BCE = nn.BCELoss()
criterion_MSE = nn.MSELoss()
# One optimiser per sub-network so each one is stepped on its own loss.
optim_enc = optim.RMSprop(vaegan.encoder.parameters(), lr=lr)
optim_dec = optim.RMSprop(vaegan.decoder.parameters(), lr=lr)
optim_dis = optim.RMSprop(vaegan.discriminator.parameters(), lr=lr)

# Fixed latent codes, decoded every 500 iterations to visualise progress.
fixed_z = torch.randn(64, z_dim, device=device)
img_list = []
losses_enc = []
losses_dec = []
losses_dis = []
iters = 0

for epoch in range(num_epochs):
  for i, (data, _) in enumerate(dataloader):
    data = data.to(device)
    mu, log_var, dis = vaegan(data)

    # Unpack the discriminator output; the input order was
    # "x_original -> x_recon -> x_sampled", each with len(data) rows.
    # Split by the size of *this* batch rather than the global batch_size:
    # the last batch of an epoch is smaller whenever the dataset size is not a
    # multiple of batch_size, and fixed-size slices/labels would then be
    # misaligned and make BCELoss fail on mismatched shapes.
    cur_batch_size = data.size(0)
    dis_x_original = dis[:cur_batch_size, :]
    dis_x_recon = dis[cur_batch_size:2*cur_batch_size, :]
    dis_x_sampled = dis[2*cur_batch_size:, :]
    # Targets match the score tensors in shape, dtype and device and carry no grad.
    real_label = torch.ones_like(dis_x_original)
    fake_label = torch.zeros_like(dis_x_original)
    bce_dis_x_original = criterion_BCE(dis_x_original, real_label)
    bce_dis_x_recon = criterion_BCE(dis_x_recon, fake_label)
    bce_dis_x_sampled = criterion_BCE(dis_x_sampled, fake_label)

    # KL divergence between N(mu, exp(log_var)) and N(0, I), averaged over the batch.
    loss_prior = torch.mean(0.5 * torch.sum(torch.pow(mu,2) + torch.exp(log_var) - log_var - 1, dim=1))
    # Reconstruction term: MSE between the discriminator's outputs for the
    # originals and for the reconstructions.
    # NOTE(review): this compares final sigmoid scores, not intermediate
    # discriminator features -- confirm this simplification is intended.
    loss_recon = criterion_MSE(dis_x_original, dis_x_recon)
    loss_gan = bce_dis_x_original + bce_dis_x_recon + bce_dis_x_sampled

    loss_enc = loss_prior + loss_recon
    loss_dec = gamma * loss_recon - loss_gan
    loss_dis = loss_gan

    losses_enc.append(loss_enc.item())
    losses_dec.append(loss_dec.item())
    losses_dis.append(loss_dis.item())

    # Each backward() is restricted via `inputs=` to one sub-network's
    # parameters, so every loss only updates its own network. The graph is
    # retained because all three losses share the same forward pass.
    optim_enc.zero_grad()
    loss_enc.backward(inputs=list(vaegan.encoder.parameters()), retain_graph=True)
    optim_enc.step()

    optim_dec.zero_grad()
    loss_dec.backward(inputs=list(vaegan.decoder.parameters()), retain_graph=True)
    optim_dec.step()

    optim_dis.zero_grad()
    loss_dis.backward(inputs=list(vaegan.discriminator.parameters()))
    optim_dis.step()

    # Periodically decode the fixed latent codes and keep an image grid.
    if iters % 500 == 0:
      with torch.no_grad():
        fake_img = vaegan.decoder(fixed_z).detach().cpu()
      img_list.append(vutils.make_grid(fake_img, padding=2, normalize=True))

    iters += 1

In [None]:
# Loss curves per training iteration.
# source from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
plt.figure(figsize=(10, 5))
for history, curve_name in (
    (losses_enc, "Encoder"),
    (losses_dec, "Decoder"),
    (losses_dis, "Discriminator"),
):
  plt.plot(history, label=curve_name)
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Animate the image grids collected during training, one frame per grid.
# source from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = []
for grid in img_list:
  # Move the channel axis last, which is the layout imshow expects.
  frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
  ims.append([frame])
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())