# Build a conditional VAE on CIFAR-10 that can generate images of 10 classes

[Benchmark](https://paperswithcode.com/sota/image-generation-on-cifar-10)

Ref.

In [None]:
import os
from typing import Type

import matplotlib.pyplot as plt

import torch
import torch_directml
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from ignite.engine import Engine, Events
from ignite.metrics import FID
import ignite.distributed as idist
from ignite.handlers import ProgressBar
from tqdm.autonotebook import tqdm
from PIL import Image


# List the DirectML devices so it is visible which adapter index 0 refers to.
for i in range(torch_directml.device_count()):
    print(i, ":", torch_directml.device_name(i))

dml = torch_directml.device(0)
print("dml =", dml)

# Model and batches are moved to this device throughout the notebook.
device = dml
# DataLoader options. Pinned (page-locked) host memory only speeds up
# host->CUDA copies; requesting it without CUDA merely makes DataLoader emit a
# warning, so enable it only when a CUDA runtime is actually available.
kwargs = {"num_workers": 1, "pin_memory": torch.cuda.is_available()}

## Load data

In [None]:
# Download (if missing) and open the official CIFAR-10 train and test splits;
# items are converted to tensors with ToTensor when indexed.
training_data, test_data = (
    datasets.CIFAR10(root="data", train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

In [None]:
# Dataset-derived constants
classes = training_data.classes  # class names, indexed by integer label
class_size = len(classes)  # number of conditioning classes (one-hot width)
image_size: int = 32  # images are image_size x image_size pixels
chanel_num: int = 3  # colour channels; spelling kept because other cells use it

# Hyperparameters
training: bool = True  # True: evaluate on the validation split, False: on the test set
activation: str = "relu"  # NOTE(review): not referenced by any visible cell
fc_output_features: int = 400  # hidden width of the encoder/decoder MLP
latent_size: int = 20  # dimensionality of the latent code z
batch_size: int = 64
epochs: int = 20

In [None]:
# Ensure the raw image arrays are uint8 (they are presumably uint8 already,
# so this is a cheap guarantee rather than a conversion).
training_data.data = training_data.data.astype("uint8")
test_data.data = test_data.data.astype("uint8")

# Reproducible 80/20 split of the official training set into train/validation.
s1, s2 = random_split(training_data, [0.8, 0.2], torch.Generator().manual_seed(42))

train_loader = DataLoader(s1, batch_size=batch_size, shuffle=True, **kwargs)
# The validation order is fixed so that the per-epoch reconstruction grid and
# the FID subset (the first few images of every batch) are the same images in
# every epoch, which keeps those per-epoch numbers comparable.
validate_loader = DataLoader(s2, batch_size=batch_size, shuffle=False, **kwargs)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, **kwargs)

## Functions

In [None]:
def plot_cifar10_images(images, labels, n):
    """Show the first ``n`` images in a 5-column grid, captioned with class names.

    Args:
        images: indexable collection of images accepted by ``plt.imshow``
            (e.g. the raw ``CIFAR10.data`` array).
        labels: integer class indices parallel to ``images``; converted to
            names through the global ``classes`` list.
        n: number of images to show; clamped to ``len(images)``.
    """
    n = min(n, len(images))
    n_cols = 5
    # Keep the original 5x5 layout for up to 25 images, and add rows instead of
    # failing with an out-of-range subplot index when more are requested.
    n_rows = max(5, (n + n_cols - 1) // n_cols)
    plt.figure(figsize=[10, 2 * n_rows])

    for i in range(n):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i])
        plt.xlabel(classes[int(labels[i])])

    plt.show()

## EDA

In [None]:
# Preview the first 10 training images together with their class names.
plot_cifar10_images(images=training_data.data, labels=training_data.targets, n=10)

In [None]:
# Report the raw array shape of each split (presumably N x 32 x 32 x 3).
for _caption, _split in (("Training images shape: ", training_data), ("Test images shape: ", test_data)):
    print(_caption, _split.data.shape)

## Preprocessing data

In [None]:
# Normalization
# training_data.data = training_data.data.astype("float32") / 255.0
# test_data.data = test_data.data.astype("float32") / 255.0

# training_data.data[0]

## Build cVAE model

[cVAE mechanism](https://idiotdeveloper.com/introduction-to-autoencoders/)

![cVAE mechanism](./images/variational-autoencoder.png)

Ref.

[Understanding Conditional Variational Autoencoders](https://towardsdatascience.com/understanding-conditional-variational-autoencoders-cd62b4f57bf8)

[Conditional Variational Autoencoder (cVAE) using PyTorch](https://github.com/unnir/cVAE)

[Conditional Variational Autoencoder in Keras](https://github.com/nnormandin/Conditional_VAE/blob/master/Conditional_VAE.ipynb)

[GAN Evaluation: the Fréchet Inception Distance and Inception Score metrics](https://colab.research.google.com/github/pytorch-ignite/pytorch-ignite.ai/blob/gh-pages/blog/2021-08-11-GAN-evaluation-using-FID-and-IS.ipynb#scrollTo=Stp59yfH65VO)

In [None]:
class CVAE(nn.Module):
    """Conditional variational autoencoder built from fully connected layers.

    It has three parts:
    1. encoder         -- q(z | x, c): flattened input plus one-hot label ->
                          mean and log-variance of a Gaussian over z.
    2. reparameterizer -- z = mu + eps * std with eps ~ N(0, I), keeping the
                          sample differentiable w.r.t. mu and logvar.
    3. decoder         -- p(x | z, c): latent code plus one-hot label ->
                          per-feature values in (0, 1) via a sigmoid.

    Args:
        feature_size: number of input features per sample (e.g. C * H * W).
        latent_size: dimensionality of the latent code z.
        class_size: number of classes, i.e. width of the one-hot condition c.
        hidden_size: width of the hidden layer in encoder and decoder. When
            omitted, the module-level ``fc_output_features`` hyperparameter is used.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, feature_size, latent_size, class_size, *args, hidden_size=None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if hidden_size is None:
            hidden_size = fc_output_features

        self.feature_size = feature_size
        self.latent_size = latent_size
        self.class_size = class_size
        self.hidden_size = hidden_size

        # encoder
        self.fc1 = nn.Linear(feature_size + class_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)  # mean head
        self.fc22 = nn.Linear(hidden_size, latent_size)  # log-variance head

        # decoder
        self.fc3 = nn.Linear(latent_size + class_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, feature_size)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x, c):  # Q(z|x, c)
        """Map a flattened batch and its condition to Gaussian parameters.

        Args:
            x: (bs, feature_size) flattened inputs.
            c: (bs, class_size) one-hot condition.

        Returns:
            ``(z_mu, z_logvar)``, each of shape (bs, latent_size).
        """
        inputs = torch.cat([x, c], 1)  # (bs, feature_size + class_size)
        h1 = self.relu(self.fc1(inputs))
        z_mu = self.fc21(h1)
        # Interpreted as log-variance (reparameterize exponentiates it).
        z_logvar = self.fc22(h1)
        return z_mu, z_logvar

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, exp(logvar)) as ``mu + eps * std`` so gradients reach mu and logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z, c)
        """Map a latent code and its condition to per-feature values in (0, 1).

        Args:
            z: (bs, latent_size) latent codes.
            c: (bs, class_size) one-hot condition.

        Returns:
            (bs, feature_size) tensor.
        """
        inputs = torch.cat([z, c], 1)  # (bs, latent_size + class_size)
        h3 = self.relu(self.fc3(inputs))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x, c):
        """Encode, sample and decode.

        Args:
            x: batch whose elements flatten to ``feature_size`` values each
                (e.g. (bs, C, H, W) images).
            c: (bs, class_size) one-hot condition.

        Returns:
            ``(reconstruction, mu, logvar)``; reconstruction is (bs, feature_size).
        """
        mu, logvar = self.encode(x.reshape(-1, self.feature_size), c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO of a batch: summed binary cross-entropy plus KL divergence.

    Args:
        recon_x: decoder output with values in [0, 1] (as produced by
            ``CVAE.decode``).
        x: targets with the same number of elements as ``recon_x`` (e.g. the
            image batch); reshaped to ``recon_x.shape`` before comparison.
        mu, logvar: mean and log-variance of q(z | x, c).

    Returns:
        Scalar tensor: BCE summed over every element plus KL summed over the batch.
    """
    # Match the target to the decoder output's shape instead of relying on the
    # global image_size/chanel_num, so the loss works for any input size.
    BCE = F.binary_cross_entropy(recon_x, x.reshape(recon_x.shape), reduction="sum")
    # Closed-form KL(q(z|x,c) || N(0, I)); see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


# Instantiate the cVAE on flattened C*H*W images and move it to the device;
# nn.Module.to returns the same module, so the rebinding is just for clarity.
model = CVAE(image_size ** 2 * chanel_num, latent_size, class_size)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)


def one_hot(labels, class_size, out_device=None):
    """Encode integer class labels as a float32 one-hot matrix.

    Args:
        labels: 1-D integer tensor of class indices in ``[0, class_size)``;
            may live on any device.
        class_size: number of classes (width of the result).
        out_device: device for the result; defaults to the module-level ``device``.

    Returns:
        float32 tensor of shape (len(labels), class_size) with one 1 per row.

    Raises:
        RuntimeError: if a label is negative or >= ``class_size``.
    """
    # Vectorized replacement for a per-label Python loop (which forced one
    # device synchronisation per element). Built on the CPU, then moved.
    targets = F.one_hot(labels.detach().cpu().long(), num_classes=class_size).float()
    return targets.to(device if out_device is None else out_device)


# Per-epoch histories, appended to by train(), test() and the training loop.
train_loss_history, test_loss_history, fid_scores = [], [], []


def train(epoch, data_loader: DataLoader):
    """Run one training epoch of the global ``model`` over ``data_loader``.

    Uses the globals ``model``, ``optimizer``, ``device`` and ``class_size``.
    Every 20th batch prints its per-sample loss. At the end, the epoch's average
    per-sample loss is appended to ``train_loss_history`` and printed.

    Args:
        epoch: epoch number, used only for logging.
        data_loader: loader (an instance, not the class) yielding
            ``(images, integer_labels)`` batches.
    """
    model.train()
    train_loss = 0.0

    for batch_idx, (data, labels) in enumerate(data_loader):
        data, labels = data.to(device), labels.to(device)
        labels = one_hot(labels, class_size)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data, labels)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        # One host sync per batch; a Python float accumulates in double precision.
        batch_loss = loss.item()
        train_loss += batch_loss

        # log
        if batch_idx % 20 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(data_loader.dataset),
                    100.0 * batch_idx / len(data_loader),
                    batch_loss / len(data),
                )
            )

    # loss_function sums over the batch, so divide by the sample count.
    train_loss /= len(data_loader.dataset)
    train_loss_history.append(train_loss)
    print("====> Epoch: {} Average loss: {:.4f}".format(epoch, train_loss))


def test(epoch, data_loader: DataLoader):
    """Evaluate the global ``model`` on ``data_loader`` without gradients.

    Appends the average per-sample loss to ``test_loss_history`` and prints it.
    It also saves a grid for the first batch: up to 5 inputs followed by their
    reconstructions, written to
    ``./outputs/cifar-10-cvae-outputs/temp/reconstruction_<epoch>.png``.

    Args:
        epoch: epoch number, used for logging and the output file name.
        data_loader: loader (an instance) yielding ``(images, integer_labels)``.
    """
    model.eval()
    test_loss = 0.0
    out_dir = "./outputs/cifar-10-cvae-outputs/temp"
    # save_image does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        for i, (data, labels) in enumerate(data_loader):
            data, labels = data.to(device), labels.to(device)
            labels = one_hot(labels, class_size)
            recon_batch, mu, logvar = model(data, labels)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()

            if i == 0:
                n = min(data.size(0), 5)
                # Flat reconstructions reshaped back to the input's image shape.
                comparison = torch.cat([data[:n], recon_batch.view_as(data)[:n]])
                save_image(comparison.cpu(), f"{out_dir}/reconstruction_{epoch:02}.png", nrow=n)

    # loss_function sums over the batch, so divide by the sample count.
    test_loss /= len(data_loader.dataset)
    test_loss_history.append(test_loss)
    print("====> Test set loss: {:.4f}".format(test_loss))


def interpolate(batch, size=299):
    """Bilinearly resize a batch of images to ``size`` x ``size``.

    Brings the 32x32 images up to the 299x299 resolution fed to the FID metric.

    Args:
        batch: float tensor of shape (N, C, H, W) with values in [0, 1]; may
            live on any device and may require grad.
        size: output height and width (default 299).

    Returns:
        Detached float32 CPU tensor of shape (N, C, size, size).
    """
    # One vectorized resize instead of a per-image PIL round trip. This avoids
    # 8-bit quantization, a hard-coded channel count, and a Python loop.
    images = batch.detach().cpu().float()
    return F.interpolate(images, size=(size, size), mode="bilinear", align_corners=False)


def evaluation(engine: Engine, batch_data):
    """ignite process function producing ``(fake, real)`` image pairs for FID.

    Reconstructs the batch with the global ``model``. It returns the first (up
    to) 5 reconstructions and their originals, both resized by ``interpolate``.
    Note that the "fake" images are reconstructions of real inputs, not
    samples drawn from the prior.

    Args:
        engine: the calling ignite ``Engine`` (unused).
        batch_data: ``(images, integer_labels)`` batch from the data loader.

    Returns:
        Tuple ``(fake, real)`` of resized image tensors.
    """
    model.eval()
    data, labels = batch_data[0], batch_data[1]

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        data, labels = data.to(device), labels.to(device)
        labels = one_hot(labels, class_size)
        recon_batch, _, _ = model(data, labels)

    # Only a few images per batch are scored (presumably to keep the Inception
    # forward passes cheap).
    n = min(data.size(0), 5)

    fake = interpolate(recon_batch.view_as(data)[:n])
    real = interpolate(data[:n])

    return fake, real


# FID metric plus an evaluation engine. Each engine iteration calls
# `evaluation` on one batch, and the attached metric accumulates the returned
# (fake, real) pairs, reporting the result under the name "fid".
fid_metric = FID(device=idist.device())
evaluator = Engine(evaluation)
fid_metric.attach(engine=evaluator, name="fid")

# ProgressBar().attach(evaluator)


# Folder for per-epoch sample grids; save_image does not create it.
sample_dir = "./outputs/cifar-10-cvae-outputs/temp"
os.makedirs(sample_dir, exist_ok=True)

for epoch in range(1, epochs + 1):
    eval_loader = validate_loader if training else test_loader

    train(epoch, train_loader)
    test(epoch, eval_loader)
    evaluator.run(eval_loader, max_epochs=1)

    metrics = evaluator.state.metrics
    fid_score = metrics["fid"]
    fid_scores.append(fid_score)

    print("====> FID score: {:.4f}".format(fid_score))

    # After every epoch, generate new images from random samples and save them.
    # Row i of the identity matrix conditions on class i, so the grid holds
    # one decoded image per class.
    with torch.no_grad():
        c = torch.eye(class_size, class_size).to(device)
        sample = torch.randn(class_size, latent_size).to(device)
        sample = model.decode(sample, c).cpu()
        images = sample.view(class_size, chanel_num, image_size, image_size)
        save_image(images, f"{sample_dir}/sample_{epoch:02}.png")

## Evaluation

In [None]:
if training:
    # The training loop numbers epochs from 1; use those numbers on the x-axis
    # instead of the 0-based list positions.
    plt.figure(figsize=[6, 4])
    plt.plot(range(1, len(train_loss_history) + 1), train_loss_history, "black", linewidth=2.0)
    plt.plot(range(1, len(test_loss_history) + 1), test_loss_history, "green", linewidth=2.0)
    plt.legend(["Training Loss", "Validation Loss"], fontsize=14)
    plt.xlabel("Epochs", fontsize=10)
    plt.ylabel("Loss", fontsize=10)
    plt.title("Loss Curves", fontsize=12)
    plt.show()

    plt.figure(figsize=[6, 4])
    plt.plot(range(1, len(fid_scores) + 1), fid_scores, "red", linewidth=2.0)
    plt.xlabel("Epochs", fontsize=10)
    plt.ylabel("FID", fontsize=10)
    plt.title("FID score", fontsize=12)
    plt.show()