# Build a conditional VAE on CIFAR-10 that can generate images of 10 classes

[Benchmark](https://paperswithcode.com/sota/image-generation-on-cifar-10)

Ref.

In [None]:
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch_directml
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torchvision.models import inception_v3, Inception_V3_Weights
from scipy import linalg
from typing import Type
from ignite.engine import Engine, Events
from PIL import Image

## Setup device

In [None]:
# List every adapter DirectML reports, so the index chosen below can be checked.
for i in range(torch_directml.device_count()):
    print(i, ":", torch_directml.device_name(i))

# Adapter 0 is used; pass another index (e.g. 1) to select a different adapter.
# `device` is the name the rest of the notebook uses; `dml` is kept as an alias.
device = dml = torch_directml.device(0)
print("dml =", dml)

## Load data

In [None]:
# Build the train split first, then the test split, with identical settings:
# files are downloaded into ./data and every image is converted by ToTensor().
training_data, test_data = (
    datasets.CIFAR10(root="data", train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

## Variables

In [None]:
# Variables
classes = training_data.classes  # class names taken from the CIFAR-10 dataset object
class_size = len(classes)  # number of classes; also the width of the one-hot condition vector
chanel_num = 3  # image channel count (the name is spelled "chanel" throughout this file)
image_size = 32  # image height and width, used when flattening / un-flattening images
train_loss_history = []  # one average-loss entry per epoch, appended by train_step

# Hyperparameters
batch_size = 64  # mini-batch size used by train_loader
epochs = 20  # number of passes over the training set
latent_size = 20  # dimension of the latent vector z
learning_rate = 0.001  # learning rate handed to the Adam optimizer
evaluation = False  # set to True to run the FID evaluation cell at the end

In [None]:
from torch.utils.data import DataLoader

# Make sure the raw image arrays of both splits are stored as uint8.
training_data.data = training_data.data.astype("uint8")
test_data.data = test_data.data.astype("uint8")

# Loader options shared by both splits.
kwargs = dict(num_workers=1, pin_memory=True)

# Training batches are shuffled; the test split is read in order, 1000 images at a time.
train_loader = DataLoader(dataset=training_data, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(dataset=test_data, batch_size=1000, shuffle=False, **kwargs)

## Show some samples

In [None]:
def show_cifar10_images(images, labels):
    """Plot the first 20 images on a 5x5 grid, each captioned with its class name.

    Args:
        images: Indexable collection of images accepted by ``plt.imshow``;
            entries 0..19 are read.
        labels: Indexable collection of class indices aligned with ``images``;
            each one is looked up in the module-level ``classes`` list.
    """
    plt.figure(figsize=[10, 10])

    # Subplot positions are 1-based, image/label indices are 0-based.
    for cell in range(1, 21):
        index = cell - 1
        plt.subplot(5, 5, cell)
        # Hide ticks and the grid so only the picture and its caption remain.
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[index])
        plt.xlabel(classes[labels[index]])

    plt.show()


# Preview the first 20 images of the training split, then of the test split,
# using the raw image arrays and integer targets stored on the dataset objects.
show_cifar10_images(training_data.data, training_data.targets)
show_cifar10_images(test_data.data, test_data.targets)

## Data shape

In [None]:
# Report the shape of the raw image array of each split.
print("Training images shape: ", training_data.data.shape)
print("Test images shape: ", test_data.data.shape)

## Build cVAE model

[cVAE mechanism](https://idiotdeveloper.com/introduction-to-autoencoders/)

![cVAE mechanism](./images/variational-autoencoder.png)

Ref.

[Understanding Conditional Variational Autoencoders](https://towardsdatascience.com/understanding-conditional-variational-autoencoders-cd62b4f57bf8)

[Conditional Variational Autoencoder (cVAE) using PyTorch](https://github.com/unnir/cVAE)

[Conditional Variational Autoencoder in Keras](https://github.com/nnormandin/Conditional_VAE/blob/master/Conditional_VAE.ipynb)

[GAN Evaluation : the Frechet Inception Distance and Inception Score metrics](https://colab.research.google.com/github/pytorch-ignite/pytorch-ignite.ai/blob/gh-pages/blog/2021-08-11-GAN-evaluation-using-FID-and-IS.ipynb#scrollTo=Stp59yfH65VO)

In [None]:
from torchinfo import summary


class CVAE(nn.Module):
    """Convolutional conditional VAE for 3-channel 32x32 images.

    The model has three parts:

    1. encoder: two strided convolutions plus fully connected layers that map
       an image and its one-hot class vector to the mean and log-variance of
       the latent distribution Q(z|x, c).
    2. reparameterizer: samples z = mu + eps * std so that sampling stays
       differentiable with respect to mu and logvar.
    3. decoder: fully connected layers plus two transposed convolutions that
       map (z, c) back to a flattened image, P(x|z, c).

    Args:
        feature_size: Number of values in one flattened image.
        latent_size: Dimension of the latent vector z.
        class_size: Number of classes (width of the one-hot condition vector).
    """

    # Image geometry the convolutional stack is built for.
    _IMAGE_CHANNELS = 3
    _IMAGE_SIZE = 32

    def __init__(self, feature_size, latent_size, class_size, *args, **kwargs) -> None:
        super(CVAE, self).__init__(*args, **kwargs)

        self.feature_size = feature_size
        self.latent_size = latent_size
        self.class_size = class_size

        # encoder: 3x3 kernels, stride 2, no padding -> 32x32 becomes 15x15, then 7x7
        self.conv_2d_1 = nn.Conv2d(3, 32, 3, 2, padding="valid")
        self.conv_2d_2 = nn.Conv2d(32, 64, 3, 2, padding="valid")

        self.fc1 = nn.Linear(7 * 7 * 64 + class_size, 400)
        self.fc21 = nn.Linear(400, latent_size)  # mean head
        self.fc22 = nn.Linear(400, latent_size)  # log-variance head

        # decoder: 7x7 becomes 15x15, then 31x31; resized to 32x32 in decode()
        self.fc3 = nn.Linear(latent_size + class_size, 400)
        self.fc4 = nn.Linear(400, 7 * 7 * 64)

        self.conv_2d_trans_1 = nn.ConvTranspose2d(64, 32, 3, 2, padding=0)
        self.conv_2d_trans_2 = nn.ConvTranspose2d(32, 3, 3, 2, padding=0)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x, c):  # Q(z|x, c)
        """Encode images and their conditions into latent mean and log-variance.

        Args:
            x: Image batch of shape (bs, 3, 32, 32).
            c: One-hot class batch of shape (bs, class_size).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (bs, latent_size).
        """
        x = self.relu(self.conv_2d_1(x))
        x = self.relu(self.conv_2d_2(x))
        # Flatten per sample; the batch dimension comes from the input itself so
        # partial batches (e.g. the last one of an epoch) work as well.
        x = x.flatten(start_dim=1)

        inputs = torch.cat([x, c], 1)  # (bs, 7 * 7 * 64 + class_size)

        h1 = self.relu(self.fc1(inputs))
        z_mu = self.fc21(h1)
        z_logvar = self.fc22(h1)
        return z_mu, z_logvar

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps drawn from N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z, c)
        """Decode latent vectors and conditions into flattened images.

        Args:
            z: Latent batch of shape (bs, latent_size).
            c: One-hot class batch of shape (bs, class_size).

        Returns:
            Tensor of shape (bs, 3 * 32 * 32) with values in [0, 1].
        """
        inputs = torch.cat([z, c], 1)  # (bs, latent_size + class_size)
        h3 = self.relu(self.fc3(inputs))
        h4 = self.relu(self.fc4(h3))
        h4 = h4.view(-1, 64, 7, 7)
        outputs = self.relu(self.conv_2d_trans_1(h4))
        # No ReLU on the last layer: sigmoid(relu(x)) can never drop below 0.5,
        # which would make dark pixels impossible to reconstruct.
        outputs = self.conv_2d_trans_2(outputs)
        outputs = self.interpolate(outputs, self._IMAGE_SIZE)
        return self.sigmoid(outputs.flatten(start_dim=1))

    def forward(self, x, c):
        """Return ``(reconstruction, mu, logvar)`` for images ``x`` and conditions ``c``."""
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar

    def interpolate(self, images, out_size):
        """Bilinearly resize a (bs, C, H, W) batch to (bs, C, out_size, out_size).

        The resize is done with a tensor op so the result stays on the input's
        device and gradients keep flowing back into the decoder.
        """
        return F.interpolate(images, size=(out_size, out_size), mode="bilinear", align_corners=False)

    def generate_images(self, gen_size):
        """Generate ``gen_size`` images from random latent vectors.

        Sample ``i`` is conditioned on class ``i % class_size``, so the classes
        repeat in order 0, 1, ..., class_size - 1, 0, 1, ...

        Returns:
            Tuple ``(gen_images, c)``: images of shape (gen_size, 3, 32, 32) on
            the CPU, and the one-hot conditions of shape (gen_size, class_size)
            on the model's device.
        """
        target_device = next(self.parameters()).device

        with torch.no_grad():
            # One valid one-hot row per sample, cycling through the classes.
            class_ids = torch.arange(gen_size) % self.class_size
            c = F.one_hot(class_ids, num_classes=self.class_size).float().to(target_device)

            # Latent vectors drawn from the standard normal prior.
            noises_z = torch.randn(gen_size, self.latent_size).to(target_device)

            samples = self.decode(noises_z, c).cpu()
            gen_images = samples.reshape(gen_size, self._IMAGE_CHANNELS, self._IMAGE_SIZE, self._IMAGE_SIZE)

            return gen_images, c


# Training model, placed on the selected device.
model = CVAE(
    feature_size=image_size * image_size * chanel_num,
    latent_size=latent_size,
    class_size=class_size,
).to(device)

# Separate CPU copy that is only used to print the layer summary below.
summary_model = CVAE(
    feature_size=image_size * image_size * chanel_num,
    latent_size=latent_size,
    class_size=class_size,
).to("cpu")

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Summarize with one image batch and one condition batch as inputs.
summary(
    summary_model,
    input_size=[
        (batch_size, chanel_num, image_size, image_size),
        (batch_size, class_size),
    ],
    device="cpu",
)

## Loss function & one-hot encoding

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Return reconstruction loss plus KL divergence, summed over elements and batch.

    See Appendix B of the VAE paper:

    Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    https://arxiv.org/abs/1312.6114

    KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

    Args:
        recon_x: Reconstructions with values in [0, 1].
        x: Target images with values in [0, 1] and the same number of elements
            as ``recon_x``; reshaped to ``recon_x``'s shape before comparison.
        mu: Latent means.
        logvar: Latent log-variances, log(sigma^2).

    Returns:
        Scalar tensor: summed binary cross-entropy + KL divergence.
    """
    # Take the target shape from the reconstruction instead of module-level
    # image constants; reshape (unlike view) also accepts non-contiguous input.
    target = x.reshape(recon_x.shape)
    BCE = F.binary_cross_entropy(recon_x, target, reduction="sum")
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


def one_hot_encode(labels, class_size):
    """Convert a batch of integer class labels to one-hot float vectors.

    Args:
        labels: 1-D tensor of class indices in ``[0, class_size)``; it may live
            on any device.
        class_size: Number of classes (width of each one-hot row).

    Returns:
        Float tensor of shape (len(labels), class_size) on the module-level
        ``device``.
    """
    # Encode on the CPU in one vectorized call, then move the result once,
    # instead of writing one element at a time from a Python loop.
    class_ids = labels.to("cpu", dtype=torch.long)
    targets = F.one_hot(class_ids, num_classes=class_size).to(torch.float32)
    return targets.to(device)

## Training step

In [None]:
def train_step(epoch, data_loader: DataLoader):
    """Train the module-level ``model`` for one epoch and record its mean loss.

    For every batch the labels are one-hot encoded, the model is run, the
    summed VAE loss is back-propagated and ``optimizer`` takes a step. Every
    20th batch a progress line is printed and a comparison image (inputs on
    top, reconstructions below) is written, overwriting the file for this epoch.

    Args:
        epoch: Epoch number, used in the log line and in the image file name.
        data_loader: Loader yielding ``(images, integer_labels)`` batches.

    Side effects:
        Appends the epoch's per-sample average loss to ``train_loss_history``
        and creates the output directory if it does not exist yet.
    """
    import os  # local import: only needed to make sure the output folder exists

    output_dir = "./outputs/cifar-10-cvae-outputs/temp/gen"
    os.makedirs(output_dir, exist_ok=True)

    model.train()
    train_loss = 0.0

    for batch_idx, (data, labels) in enumerate(data_loader):
        data, labels = data.to(device), labels.to(device)
        labels = one_hot_encode(labels, class_size)

        optimizer.zero_grad()

        recon_batch, mu, logvar = model(data, labels)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()

        optimizer.step()

        # Read the loss back once, as a Python float, so the running total is
        # accumulated in double precision rather than float32.
        batch_loss = loss.item()
        train_loss += batch_loss

        # log
        if batch_idx % 20 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    (batch_idx + 1) * len(data),
                    len(data_loader.dataset),
                    100.0 * (batch_idx + 1) / len(data_loader),
                    batch_loss / len(data),
                ),
                end="\r",
            )

            # Up to five inputs stacked above their reconstructions.
            n = min(data.size(0), 5)
            comparison = torch.cat([data[:n], recon_batch.view(-1, chanel_num, image_size, image_size)[:n]])
            save_image(
                comparison.cpu(),
                f"{output_dir}/reconstruction_{epoch:02}.png",
                nrow=n,
            )

    # The loss is summed over samples, so divide by the dataset size.
    train_loss /= len(data_loader.dataset)
    train_loss_history.append(train_loss)

    print("\n====> Epoch: {} Average loss: {:.4f}".format(epoch, train_loss))

In [None]:
# Run every epoch; after each one, generate ten images from random latent
# vectors and save them under a file name numbered by the epoch.
for epoch in range(1, epochs + 1):
    train_step(epoch, train_loader)
    samples, labels = model.generate_images(10)
    save_image(samples, "./outputs/cifar-10-cvae-outputs/temp/gen/sample_" + f"{epoch:02}" + ".png")

In [None]:
# Plot the per-epoch average training loss collected in train_loss_history.
plt.figure(figsize=(6, 4))
plt.plot(train_loss_history, "green", linewidth=2.0)
plt.ylabel("Loss", fontsize=10)
plt.xlabel("Epochs", fontsize=10)
plt.legend(["Training Loss"], fontsize=14)
plt.title("Loss Curves", fontsize=12)

In [None]:
# Generate the image set that the FID evaluation below compares with the test split.
gen_size = 10000
gen_images, gen_labels = model.generate_images(gen_size)

# Keep the first ten generated images on disk for a quick visual check.
save_image(gen_images[:10], "./outputs/cifar-10-cvae-outputs/temp/gen/evaluation.png")

# imshow needs image objects, so convert every generated tensor to a PIL image.
pil_img = [transforms.ToPILImage()(img) for img in gen_images]

# The captions repeat 0..class_size-1; this assumes the generated images cycle
# through the classes in that order -- verify against CVAE.generate_images.
show_cifar10_images(
    pil_img,
    np.tile(np.arange(class_size), gen_size // class_size),
)

## FID evaluation

Ref.

Inception v3 architecture

![Inception v3 architecture](./images/inception_v3.png)

[GAN in Pytorch with FID](https://www.kaggle.com/code/ibtesama/gan-in-pytorch-with-fid#References)

In [None]:
from torcheval.metrics.image import FrechetInceptionDistance


def interpolate(images):
    """Resize a batch of images to 299x299 for the Inception v3 model.

    Inception v3 expects 299 * 299 * 3 input, so the 32 * 32 * 3 CIFAR-10
    images have to be resized first. Each image is converted to a PIL image,
    resized with bilinear filtering and converted back to a tensor.

    Args:
        images: Iterable of image tensors accepted by ``transforms.ToPILImage``.

    Returns:
        The resized images stacked into a single tensor.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    resized = [to_tensor(to_pil(img).resize((299, 299), Image.BILINEAR)) for img in images]

    return torch.stack(resized)


# One shared metric object, kept on the CPU and reset after every evaluation.
fid_metric = FrechetInceptionDistance().to("cpu")


def fid_evaluation(real, gen):
    """Return the FID between a batch of real and a batch of generated images.

    Both batches are resized with ``interpolate`` and fed to the shared
    ``fid_metric``, which is reset after the score is computed so the next
    call starts from a clean state.

    Args:
        real: Batch of real images.
        gen: Batch of generated images.

    Returns:
        The value returned by ``fid_metric.compute()``.
    """
    gen_resized = interpolate(gen)
    real_resized = interpolate(real)

    fid_metric.update(real_resized, is_real=True)
    fid_metric.update(gen_resized, is_real=False)

    score = fid_metric.compute()
    fid_metric.reset()
    return score


if evaluation:
    # One FID score per test batch; the mean is reported at the end.
    fids = []

    for batch_idx, (data, labels) in enumerate(test_loader):
        # Pair each test batch with the matching slice of 1000 generated images
        # (1000 is the batch size test_loader was created with).
        fid = fid_evaluation(data, gen_images[batch_idx * 1000 : (batch_idx + 1) * 1000])
        fids.append(fid)

        print(
            "FID Evaluation: [{}/{} ({:.0f}%)]\tFID: {:.6f}".format(
                (batch_idx + 1) * len(data),
                len(test_loader.dataset),
                100.0 * (batch_idx + 1) / len(test_loader),
                fid,
            ),
            end="\r",
        )

    print("\nAverage FID:", np.mean(fids))