# 🤪 Variational Autoencoders - CelebA Faces

In this notebook, we'll walk through the steps required to train your own variational autoencoder on the CelebA faces dataset.

In [None]:
import os, sys
from dotenv import load_dotenv

# Read settings from a local .env file (if any) into the process environment.
load_dotenv()
python_path = os.getenv('PYTHONPATH')
data_path = os.getenv('DATA_PATH')

# Append every PYTHONPATH entry that is not already on sys.path, presumably so
# that the project-local import further down can be resolved.
for path in (python_path.split(os.pathsep) if python_path else ()):
    if path in sys.path:
        continue
    sys.path.append(path)


import numpy as np
import pandas as pd
from PIL import Image
from scipy.stats import norm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

from tqdm import tqdm, trange

from notebooks.pt_utils import display

## 0. Parameters <a name="parameters"></a>

In [None]:
# Images are resized to IMAGE_SIZE x IMAGE_SIZE; CHANNELS is the number of input channels.
IMAGE_SIZE = 32
CHANNELS = 3
# Mini-batch size shared by the train / validation / test loaders.
BATCH_SIZE = 256
# Channel count of every hidden (transposed-)convolution layer in the VAE.
NUM_FEATURES = 128
# Dimensionality of the latent vector z.
Z_DIM = 200
# Learning rate passed to the Adam optimizer.
LEARNING_RATE = 0.0005
# Number of training epochs.
EPOCHS = 10
# Multiplier applied to the MSE reconstruction loss before the KL term is added.
BETA = 2000
# NOTE(review): not referenced anywhere in the visible notebook code — confirm
# whether checkpoint loading is still intended.
LOAD_MODEL = False

# Worker processes used by each DataLoader.
NUM_WORKERS = 24
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 1. Prepare the data <a name="prepare"></a>

In [None]:
class CelebADataset(Dataset):
    """Map-style dataset yielding ``(image, attributes)`` pairs.

    The attribute table is read from ``<root>/list_attr_celeba.csv``; its
    ``image_id`` column names the files under
    ``<root>/img_align_celeba/img_align_celeba`` and every column after the
    first is kept as that image's attribute list.

    Args:
        root: Directory containing the CSV file and the image folder.
        transform: Optional callable applied to the opened PIL image.
        target_transform: Optional callable applied to the attribute list.
    """

    def __init__(self, root, transform=None, target_transform=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        image_dir = os.path.join(root, 'img_align_celeba', 'img_align_celeba')
        table = pd.read_csv(os.path.join(root, 'list_attr_celeba.csv'))

        self.n_samples = len(table)
        # Full path of every image, in CSV row order.
        self.image_files = [
            os.path.join(image_dir, file_name) for file_name in table['image_id']
        ]
        # Everything except the first column is treated as attribute values.
        self.image_attrs = table.iloc[:, 1:].values.tolist()

    def __len__(self):
        """Return the number of rows in the attribute table."""
        return self.n_samples

    def __getitem__(self, index):
        """Open image ``index`` and return it with its attribute list."""
        image = Image.open(self.image_files[index])
        attrs = self.image_attrs[index]

        # Each transform is optional; a falsy value means "leave as is".
        image = self.transform(image) if self.transform else image
        attrs = self.target_transform(attrs) if self.target_transform else attrs

        return image, attrs

In [None]:
# Resize every image to IMAGE_SIZE x IMAGE_SIZE and convert it to a tensor;
# attribute lists are turned into tensors via torch.Tensor.
image_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])
celeba_root = os.path.join(data_path, 'celeba-dataset')
full_dataset = CelebADataset(
    root=celeba_root,
    transform=image_transform,
    target_transform=torch.Tensor,
)

# Two-stage random split: hold out 10% for testing, then 20% of the remainder
# for validation (72% / 18% / 10% of the full dataset overall).
train_valid_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [0.9, 0.1])
train_dataset, valid_dataset = torch.utils.data.random_split(train_valid_dataset, [0.8, 0.2])

In [None]:
# Settings shared by all three loaders. Host-memory pinning is only requested
# when CUDA is actually available: the notebook's `device` falls back to the
# CPU, and unconditionally asking for memory pinned for 'cuda' does not work on
# a CPU-only machine. With CUDA present the arguments are the same as before.
use_cuda = torch.cuda.is_available()
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=use_cuda,
    pin_memory_device='cuda' if use_cuda else '',
)

# Only the training data is shuffled; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Pull one batch from the training loader (attributes discarded) and hand the
# images to the project's `display` helper as a quick visual sanity check.
images_sample, _ = next(iter(train_loader))
display(images_sample, cmap=None)

## 2. Build the variational autoencoder <a name="build"></a>

In [None]:
class VariationalAutoencoder(nn.Module):
    """Convolutional variational autoencoder for image batches.

    The encoder applies four stride-2 convolutions, flattens the result and
    maps it to a latent mean and log-variance, from which a latent vector is
    sampled. The decoder mirrors the encoder with four stride-2 transposed
    convolutions and ends in a sigmoid.

    Args:
        input_size: ``(channels, dim1, dim2)`` of one input image, given in the
            same order as the image tensor's non-batch dimensions.
        latent_size: Dimensionality of the latent vector.

    Note:
        The hidden layer width comes from the module-level ``NUM_FEATURES``.
        The decoder upsamples the bottleneck by a factor of 16 per spatial
        dimension, so reconstructions only match the input shape when both
        spatial sizes are multiples of 16.
    """

    def __init__(self, input_size, latent_size):

        super().__init__()

        c, w, h = input_size
        z_size = latent_size

        def size_after_encoder(size, n_layers=4):
            # Each Conv2d(kernel=3, stride=2, padding=1) maps n -> (n + 1) // 2.
            for _ in range(n_layers):
                size = (size + 1) // 2
            return size

        # Shape of the encoder's last feature map, derived from the inputs
        # instead of being hard-coded; this is (128, 2, 2) for
        # NUM_FEATURES=128 and 32x32 images.
        unflat_size = (NUM_FEATURES, size_after_encoder(w), size_after_encoder(h))
        # Cast to a plain int: np.prod returns a NumPy scalar.
        flat_size = int(np.prod(unflat_size))

        class SampleZ(nn.Module):
            """Draw z = mean + exp(0.5 * logvar) * eps with eps ~ N(0, I)."""

            def __init__(self):
                super().__init__()

            def forward(self, x):
                z_mean, z_logvar = x
                epsilon = torch.randn_like(z_logvar)
                return z_mean + torch.exp(0.5 * z_logvar) * epsilon

        class Encoder(nn.Module):
            """Map an image batch to ``(z, z_mean, z_logvar)``."""

            def __init__(self):
                super().__init__()

                self.seq = nn.Sequential(
                    nn.Conv2d(c, NUM_FEATURES, 3, stride=2, padding=1),
                    nn.BatchNorm2d(num_features=NUM_FEATURES),
                    nn.LeakyReLU(),
                    nn.Conv2d(NUM_FEATURES, NUM_FEATURES, 3, stride=2, padding=1),
                    nn.BatchNorm2d(num_features=NUM_FEATURES),
                    nn.LeakyReLU(),
                    nn.Conv2d(NUM_FEATURES, NUM_FEATURES, 3, stride=2, padding=1),
                    nn.BatchNorm2d(num_features=NUM_FEATURES),
                    nn.LeakyReLU(),
                    nn.Conv2d(NUM_FEATURES, NUM_FEATURES, 3, stride=2, padding=1),
                    nn.BatchNorm2d(num_features=NUM_FEATURES),
                    nn.LeakyReLU(),

                    nn.Flatten(),
                )

                self.z_mean_branch = nn.Linear(flat_size, z_size)
                self.z_logvar_branch = nn.Linear(flat_size, z_size)
                self.z_sampler = SampleZ()

            def forward(self, x):
                x = self.seq(x)
                z_mean = self.z_mean_branch(x)
                z_logvar = self.z_logvar_branch(x)
                z = self.z_sampler([z_mean, z_logvar])

                return z, z_mean, z_logvar

        class Decoder(nn.Module):
            """Map latent vectors back to images with values squashed by a sigmoid."""

            def __init__(self):
                super().__init__()

                self.seq = nn.Sequential(
                    nn.Linear(z_size, flat_size),
                    nn.BatchNorm1d(num_features=flat_size),
                    nn.LeakyReLU(),

                    nn.Unflatten(dim=1, unflattened_size=unflat_size),

                    nn.ConvTranspose2d(unflat_size[0], NUM_FEATURES, 3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(num_features=NUM_FEATURES),
                    nn.LeakyReLU(),
                    nn.ConvTranspose2d(NUM_FEATURES, NUM_FEATURES, 3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(num_features=NUM_FEATURES),
                    nn.LeakyReLU(),
                    nn.ConvTranspose2d(NUM_FEATURES, NUM_FEATURES, 3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(num_features=NUM_FEATURES),
                    nn.LeakyReLU(),
                    nn.ConvTranspose2d(NUM_FEATURES, NUM_FEATURES, 3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(num_features=NUM_FEATURES),
                    nn.LeakyReLU(),

                    nn.ConvTranspose2d(NUM_FEATURES, c, 3, stride=1, padding=1),
                    nn.Sigmoid(),
                )

            def forward(self, x):
                return self.seq(x)

        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, sample a latent vector and return its reconstruction."""
        z, _, _ = self.encoder(x)
        return self.decoder(z)

    def _construction_loss(self, reconstructed, original, beta=BETA):
        """Return the MSE between reconstruction and original, scaled by ``beta``."""
        return beta * F.mse_loss(reconstructed, original)

    def _kl_loss(self, z_mean, z_logvar):
        """Return the KL divergence to N(0, I), summed over latent dims and averaged over the batch."""
        return torch.mean(
            torch.sum(
                -0.5 * (1 + z_logvar - torch.square(z_mean) - torch.exp(z_logvar)),
                dim=1
            )
        )

    def fit(self, train_loader, valid_loader, epochs, optimizer, beta=BETA):
        """Train the model and report validation losses after every epoch.

        Args:
            train_loader: Iterable of ``(images, _)`` batches used for training.
            valid_loader: Iterable of ``(images, _)`` batches used for validation.
            epochs: Number of passes over ``train_loader``.
            optimizer: Optimizer already bound to this model's parameters.
            beta: Weight of the reconstruction term relative to the KL term.

        The model is left in eval mode when this method returns.
        """
        # Move batches to wherever the model lives instead of relying on a
        # module-level `device` that may not match.
        model_device = next(self.parameters()).device

        with trange(epochs, desc='Epochs', unit='epoch', leave=True) as pbar:
            for _epoch in range(epochs):

                # Train
                self.train()
                for images, _ in train_loader:
                    images = images.to(model_device)
                    optimizer.zero_grad()
                    z, z_mean, z_logvar = self.encoder(images)
                    reconstructions = self.decoder(z)
                    loss = (
                        self._construction_loss(reconstructions, images, beta) +
                        self._kl_loss(z_mean, z_logvar)
                    )
                    loss.backward()
                    optimizer.step()

                # Validation: accumulate plain floats so no tensors are kept
                # alive across batches.
                total_val_loss = 0.0
                total_recon_loss = 0.0
                total_kldiv_loss = 0.0
                n_val_batches = 0
                self.eval()
                with torch.no_grad():
                    for images, _ in valid_loader:
                        images = images.to(model_device)
                        z, z_mean, z_logvar = self.encoder(images)
                        reconstructions = self.decoder(z)
                        recon_loss = self._construction_loss(reconstructions, images, beta)
                        kldiv_loss = self._kl_loss(z_mean, z_logvar)
                        total_recon_loss += recon_loss.item()
                        total_kldiv_loss += kldiv_loss.item()
                        total_val_loss += (recon_loss + kldiv_loss).item()
                        n_val_batches += 1

                # An empty validation loader yields NaN averages instead of
                # failing on an undefined batch counter.
                denominator = n_val_batches if n_val_batches else float('nan')
                avg_val_loss = total_val_loss / denominator
                avg_recon_loss = total_recon_loss / denominator
                avg_kldiv_loss = total_kldiv_loss / denominator
                postfix_str = f'Loss: {avg_val_loss:0.4f}, Reconst. Loss: {avg_recon_loss:0.4f}, KL Loss: {avg_kldiv_loss:0.4f}'
                pbar.set_postfix_str(postfix_str)
                pbar.update()



## 3. Train the variational autoencoder <a name="train"></a>

In [None]:
# Create a variational autoencoder sized for CHANNELS x IMAGE_SIZE x IMAGE_SIZE
# inputs with a Z_DIM-dimensional latent space, and move it to `device`.
vae = VariationalAutoencoder(input_size=(CHANNELS, IMAGE_SIZE, IMAGE_SIZE), latent_size=Z_DIM).to(device)

# Adam over all model parameters (encoder and decoder).
optimizer = torch.optim.Adam(params=vae.parameters(), lr=LEARNING_RATE)

# Train for EPOCHS epochs; validation losses are shown on the progress bar.
vae.fit(
    train_loader=train_loader,
    valid_loader=valid_loader,
    epochs=EPOCHS,
    optimizer=optimizer,
    beta=BETA,
)

## 3. Reconstruct using the variational autoencoder <a name="reconstruct"></a>

In [None]:
# Select a subset of the test set: the first batch yielded by the test loader
# (attributes are discarded).
example_images, _ = next(iter(test_loader))

In [None]:
# Create autoencoder predictions and display.
# Like the helper functions further down, run inference in eval mode (so the
# BatchNorm layers use their running statistics even if this cell is executed
# before or without training) and without building an autograd graph.
vae.eval()
with torch.no_grad():
    reconstructions = vae(example_images.to(device))
print("Example real faces")
display(example_images)
print("Reconstructions")
display(reconstructions.detach().cpu())

## 4. Latent space distribution

In [None]:
# Encode the example batch in eval mode and without tracking gradients,
# consistent with the helper functions further down.
vae.eval()
with torch.no_grad():
    z, _, _ = vae.encoder(example_images.to(device))
z = z.detach().cpu()

# Standard normal density, drawn over each histogram for comparison.
x = np.linspace(-3, 3, 100)

fig = plt.figure(figsize=(20, 5))
fig.subplots_adjust(hspace=0.6, wspace=0.4)

# Plot at most 50 latent dimensions (a 5 x 10 grid); the bound avoids an
# IndexError when the latent space has fewer than 50 dimensions.
n_dims_to_plot = min(50, z.shape[1])
for i in range(n_dims_to_plot):
    ax = fig.add_subplot(5, 10, i + 1)
    ax.hist(z[:, i].numpy(), density=True, bins=20)
    ax.axis("off")
    ax.text(
        0.5, -0.35, str(i), fontsize=10, ha="center", transform=ax.transAxes
    )
    ax.plot(x, norm.pdf(x))

plt.show()

## 5. Generate new faces <a name="decode"></a>

In [None]:
# Sample some points in the latent space, from the standard normal distribution:
# one Z_DIM-dimensional point per cell of a grid_width x grid_height grid.
grid_width, grid_height = (10, 3)
z_sample = torch.randn((grid_width * grid_height, Z_DIM))

In [None]:
# Decode the sampled points.
# Eval mode keeps the decoder's BatchNorm layers on their running statistics
# (rather than the statistics of this random batch), and no_grad avoids
# building an autograd graph — consistent with the helper functions below.
vae.eval()
with torch.no_grad():
    reconstructions = vae.decoder(z_sample.to(device))
reconstructions = reconstructions.detach().cpu()

In [None]:
# Draw a plot of decoded images
n_cells = grid_width * grid_height
fig = plt.figure(figsize=(18, 5))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

# Output the grid of faces; subplot positions are 1-based.
for position in range(1, n_cells + 1):
    face = reconstructions[position - 1]
    ax = fig.add_subplot(grid_height, grid_width, position)
    ax.axis("off")
    # Move the channel dimension last, as imshow expects.
    ax.imshow(face.permute(1, 2, 0))

## 6. Manipulate the images <a name="manipulate"></a>

In [None]:
# Load the attribute (label) names: every CSV column after the first.
# nrows=0 parses only the header row, so the full table is not read just to
# obtain its column names.
attributes = pd.read_csv(
    os.path.join(data_path, 'celeba-dataset', 'list_attr_celeba.csv'),
    nrows=0,
).columns[1:].tolist()
attributes

In [None]:
def get_vector_of_label(label, data_loader, vae, z_dim):
    """Estimate the unit latent-space direction associated with an attribute.

    Batches are encoded one at a time while running means of the latent
    vectors are kept for samples whose attribute equals 1 and for samples
    whose attribute equals -1. The result is the normalised difference of the
    two means. Iteration stops early once both means move by less than a fixed
    threshold between consecutive batches.

    Args:
        label: Attribute name; looked up in the module-level ``attributes`` list.
        data_loader: Loader yielding ``(images, attrs)`` batches.
        vae: Model whose ``encoder`` returns ``(z, z_mean, z_logvar)``.
        z_dim: Dimensionality of the latent space.

    Returns:
        A tensor of length ``z_dim`` on ``device``; all zeros when no usable
        direction could be computed.

    Raises:
        ValueError: If ``label`` is not in ``attributes``.
    """
    vae.eval()
    threshold = 0.05
    label_index = attributes.index(label)

    n_pos = 0
    sum_pos = torch.zeros(z_dim).to(device)
    mean_pos = torch.zeros_like(sum_pos)
    # Start from the zero mean so a first batch without positive samples does
    # not leave this name unbound.
    current_mean_pos = mean_pos

    n_neg = 0
    sum_neg = torch.zeros(z_dim).to(device)
    mean_neg = torch.zeros_like(sum_neg)
    current_mean_neg = mean_neg

    # Created on `device` so the return value lives on the same device whether
    # or not the loop body ever runs.
    vec = torch.zeros(z_dim).to(device)

    with torch.no_grad(), trange(len(data_loader) * data_loader.batch_size, desc="Samples", unit='sample', leave=True) as pbar:
        for images, attrs in data_loader:

            z, _, _ = vae.encoder(images.to(device))

            # Keep the boolean masks on the same device as the latent vectors.
            label_values = attrs[:, label_index].to(z.device)
            z_pos = z[label_values == 1]
            z_neg = z[label_values == -1]

            if len(z_pos) > 0:
                sum_pos += torch.sum(z_pos, dim=0)
                n_pos += len(z_pos)
                current_mean_pos = sum_pos / n_pos

            if len(z_neg) > 0:
                sum_neg += torch.sum(z_neg, dim=0)
                n_neg += len(z_neg)
                current_mean_neg = sum_neg / n_neg

            vec = current_mean_pos - current_mean_neg
            dist = torch.norm(vec)
            # Normalise only when the difference is non-zero; dividing by a
            # zero norm would fill the vector with NaNs.
            if dist > 0:
                vec = vec / dist

            pbar.update(len(images))
            postfix_str = f'distance: {dist.item():0.4f}'
            pbar.set_postfix_str(postfix_str, refresh=True)

            dist_pos = torch.norm(current_mean_pos - mean_pos)
            dist_neg = torch.norm(current_mean_neg - mean_neg)
            mean_pos = current_mean_pos
            mean_neg = current_mean_neg

            # Convergence is only meaningful once both classes have been seen.
            if n_pos > 0 and n_neg > 0 and dist_pos + dist_neg < threshold:
                print(f'Found the {label} vector')
                break

    return vec

In [None]:
def add_vector_to_images(data_loader, vae, feature_vec):
    """Plot faces next to decodings of their latent vector shifted along ``feature_vec``.

    The first batch of ``data_loader`` is encoded. For up to five of its
    images, one row is drawn: the original image followed by the decoding of
    ``z + factor * feature_vec`` for each factor from -4 to 4.

    Args:
        data_loader: Loader yielding ``(images, attrs)`` batches.
        vae: Model exposing ``encoder`` and ``decoder``.
        feature_vec: Latent-space direction to add, of the same length as ``z``.
    """
    vae.eval()
    with torch.no_grad():
        factors = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
        images, _ = next(iter(data_loader))
        # Never try to show more rows than the batch contains.
        n_to_show = min(5, len(images))

        z, _, _ = vae.encoder(images.to(device))
        feature_vec = feature_vec.to(z.device)

        fig = plt.figure(figsize=(18, 10))
        n_cols = len(factors) + 1

        counter = 1

        for i in range(n_to_show):
            # Show the image that was actually encoded for this row, taken from
            # the batch fetched above rather than from a notebook-level global.
            img = images[i]
            sub = fig.add_subplot(n_to_show, n_cols, counter)
            sub.axis('off')
            sub.imshow(img.permute(1, 2, 0))

            counter += 1

            for factor in factors:
                changed_z = z[i] + feature_vec * factor
                # The decoder expects a batch dimension; add it, then strip it.
                changed_image = vae.decoder(torch.unsqueeze(changed_z, dim=0)).detach().cpu()[0]
                sub = fig.add_subplot(n_to_show, n_cols, counter)
                sub.axis('off')
                sub.imshow(changed_image.permute(1, 2, 0))

                counter += 1

        plt.show()

In [None]:
def morph_faces(data_loader, vae):
    """Plot a latent-space interpolation between the first two faces of a batch.

    The row shows the first image, the decodings of
    ``z0 * (1 - factor) + z1 * factor`` for factors 0.0, 0.1, ..., 0.9, and
    finally the second image.

    Args:
        data_loader: Loader yielding ``(images, attrs)`` batches; its first
            batch must contain at least two images.
        vae: Model exposing ``encoder`` and ``decoder``.
    """
    vae.eval()

    with torch.no_grad():
        factors = np.arange(0, 1.0, 0.1)

        batch_images, _ = next(iter(data_loader))
        # Only two faces are needed, so slice the image tensor itself (slicing
        # the (images, attrs) tuple would keep the whole batch) and encode
        # just those two.
        images = batch_images[:2]
        z, _, _ = vae.encoder(images.to(device))

        fig = plt.figure(figsize=(18, 8))
        n_cols = len(factors) + 2

        counter = 1

        # Leftmost cell: the starting face.
        sub = fig.add_subplot(1, n_cols, counter)
        sub.axis('off')
        sub.imshow(images[0].permute(1, 2, 0))

        counter += 1

        for factor in factors:
            changed_z = z[0] * (1-factor) + z[1] * factor
            # The decoder expects a batch dimension; add it, then strip it.
            changed_img = vae.decoder(torch.unsqueeze(changed_z, dim=0)).detach().cpu()[0]
            sub = fig.add_subplot(1, n_cols, counter)
            sub.axis('off')
            sub.imshow(changed_img.permute(1, 2, 0))

            counter += 1

        # Rightmost cell: the target face.
        sub = fig.add_subplot(1, n_cols, counter)
        sub.axis('off')
        sub.imshow(images[1].permute(1, 2, 0))

        plt.show()

In [None]:
# Estimate the latent-space direction for the 'Blond_Hair' attribute from the test loader.
attribute_vec = get_vector_of_label('Blond_Hair', test_loader, vae, Z_DIM)

In [None]:
# Plot test faces decoded with the attribute vector added at several strengths.
add_vector_to_images(test_loader, vae, attribute_vec)

In [None]:
# Plot a latent-space interpolation between two faces from the first test batch.
morph_faces(test_loader, vae)