# Implementation of Variational AutoEncoder (VAE)

     VAE from "Auto-Encoding Variational Bayes" (2014, D.P. Kingma et al.)
    
     Kernel-author: Jackson Kang @ Deep-learning Lab. (Handong Global University, S.Korea)
     
     Author-email:  mskang@handong.edu

     python and pytorch version: python=3.7, pytorch=1.3.1

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image

import os
import pandas as pd
from torchvision.io import read_image
from torchvision.io import ImageReadMode

from torchvision.datasets import MNIST
import torchvision.transforms.v2 as v2
from torch.utils.data import DataLoader

In [2]:
class CustomImageDataset(Dataset):
    """Image dataset described by a CSV annotations file.

    The CSV is loaded with ``pandas.read_csv``. For every row, column 0 is used
    as the image file name (joined onto ``img_dir``) and column 1 as the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Read the annotations table and store the directory and transforms.

        Args:
            annotations_file: Path (or file-like object) handed to ``pd.read_csv``.
            img_dir: Directory that the file names in column 0 are relative to.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for row ``idx``.

        The image is decoded as RGB with ``torchvision.io.read_image``; the
        transforms are applied only when they are set (truthy).
        """
        file_name = self.img_labels.iloc[idx, 0]
        pixels = read_image(os.path.join(self.img_dir, file_name), mode=ImageReadMode.RGB)
        target = self.img_labels.iloc[idx, 1]

        pixels = self.transform(pixels) if self.transform else pixels
        target = self.target_transform(target) if self.target_transform else target
        return pixels, target

In [3]:
# Model Hyperparameters

# Image folder and the CSV annotations file that lives inside it.
dataset_path = '/home/daleas@ads.iu.edu/Pytorch-VAE/ship_data'
labels_path = os.path.join(dataset_path, 'labels.csv')

# Select the GPU only when one is actually available, so that every later
# `.to(DEVICE)` call also works on CPU-only machines.
cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if cuda else "cpu")


batch_size = 1

# Passed to Encoder/Decoder as input_dim/output_dim.
# NOTE(review): the training transform resizes images to 128x128 and the
# training loop views batches as (batch_size, 3, 128, 128) -- confirm whether
# (3, 512, 512) is still the intended value here.
x_dim  = (3, 512, 512)
hidden_dim = 512   # width of the encoder's fully connected layer
latent_dim = 32    # size of the latent vector z

lr = 1e-3

epochs = 3

###    Step 1. Load (or download) Dataset

In [4]:
# Preprocessing: resize to 128x128, then convert the uint8 pixels to float32
# scaled into [0, 1]. The values are deliberately left in [0, 1] (no mean/std
# normalisation, no PIL round-trip) because the loss defined further down uses
# binary cross-entropy, which requires targets inside [0, 1].
ships_transform = v2.Compose([
    v2.Resize((128, 128)),
    v2.ToDtype(torch.float32, scale=True),
])

# Pinned host memory only helps (and only avoids a warning) when CUDA is present.
kwargs = {'num_workers': 1, 'pin_memory': torch.cuda.is_available()}

# Both datasets read the same annotations file; the test set gets the same
# transform so the model sees float inputs of the size it was trained on.
my_train_data = CustomImageDataset(annotations_file=labels_path, img_dir=dataset_path, transform=ships_transform)
my_test_data = CustomImageDataset(annotations_file=labels_path, img_dir=dataset_path, transform=ships_transform)


train_loader = DataLoader(dataset=my_train_data, batch_size=batch_size, shuffle=True, **kwargs)
test_loader  = DataLoader(dataset=my_test_data,  batch_size=batch_size, shuffle=False, **kwargs)



# Validate training data: print the index and tensor shape of every batch.

for batch_idx, (x, _) in enumerate(train_loader):

    print(batch_idx, np.array(x.shape))

### Step 2. Define our model: Variational AutoEncoder (VAE)

In [5]:
"""
    A simple implementation of Gaussian MLP Encoder and Decoder
"""

class Encoder(nn.Module):
    """Convolutional Gaussian encoder.

    Three stride-2 5x5 convolutions (3 -> 128 -> 256 -> 512 channels) feed a
    flatten step, a lazily sized fully connected layer of width ``hidden_dim``
    and two linear heads of width ``latent_dim`` that output the mean and the
    log-variance of the approximate posterior q(z|x).
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        """Build the layers.

        Args:
            input_dim: Accepted for interface compatibility; not used by the
                layers constructed here.
            hidden_dim: Output width of the fully connected layer.
            latent_dim: Output width of the mean and log-variance heads.
        """
        super().__init__()

        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=128, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2)

        # LazyLinear infers its input width from the first batch it sees.
        self.FC_layer = nn.LazyLinear(out_features=hidden_dim)
        self.flat = nn.Flatten()

        # Heads for the parameters of q(z|x).
        self.FC_mean = nn.Linear(in_features=hidden_dim, out_features=latent_dim)
        self.FC_var = nn.Linear(in_features=hidden_dim, out_features=latent_dim)

        self.LeakyReLU = nn.LeakyReLU(negative_slope=0.2)

        self.training = True

    def forward(self, input):
        """Encode a batch of images.

        Returns:
            Tuple ``(c1, c2, c3, h_, mean, log_var)``: the three ReLU-activated
            convolution outputs, the ReLU-activated hidden vector, and the
            mean / log-variance of the simple tractable normal distribution q.
        """
        activations = []
        features = input
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
            activations.append(features)
        c1, c2, c3 = activations

        h_ = F.relu(self.FC_layer(self.flat(c3)))

        # The two heads are evaluated in order: mean first, then log-variance.
        return c1, c2, c3, h_, self.FC_mean(h_), self.FC_var(h_)

In [6]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent vector to ``hidden_dim * 13 * 13``
    values, which are reshaped to ``(hidden_dim, 13, 13)`` and upsampled by
    three stride-2 5x5 transposed convolutions (-> 256 -> 128 -> 3 channels).
    By the ConvTranspose2d size formula the spatial size goes
    13 -> 29 -> 62 -> 128.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        """Build the layers.

        Args:
            latent_dim: Size of the latent vector fed to ``forward``.
            hidden_dim: Channel count of the 13x13 feature map the latent
                vector is expanded to.
            output_dim: Accepted for interface compatibility; not used by the
                layers constructed here.
        """
        super().__init__()

        # Kept so forward() reshapes to the channel count FC_hidden was sized
        # for, instead of assuming 512.
        self.hidden_dim = hidden_dim

        self.FC_hidden = nn.Linear(latent_dim, hidden_dim*13*13)
        self.conv3 = nn.LazyConvTranspose2d(256, (5,5), stride=2)
        self.conv2 = nn.LazyConvTranspose2d(128, (5,5), stride=2, output_padding=(1, 1))
        self.conv1 = nn.LazyConvTranspose2d(3, (5,5), stride=2, output_padding=(1, 1))

    def forward(self, x):
        """Decode a batch of latent vectors.

        Returns:
            Tuple ``(h_, c3, c2, c1, img)``: the expanded hidden vector, its
            reshaped feature map, the two intermediate ReLU activations, and
            the output image with every value in [0, 1].
        """
        h_ = F.relu(self.FC_hidden(x))
        c3 = torch.reshape(h_, (-1, self.hidden_dim, 13, 13))

        c2 = F.relu(self.conv3(c3))
        c1 = F.relu(self.conv2(c2))
        # Sigmoid (not ReLU) on the output: the reconstruction loss in this
        # file is binary cross-entropy, which requires predictions in [0, 1],
        # and ReLU is unbounded above.
        img = torch.sigmoid(self.conv1(c1))

        return h_, c3, c2, c1, img
        

In [7]:
class Model(nn.Module):
    """Variational autoencoder: encode, sample a latent vector, decode."""

    def __init__(self, Encoder, Decoder):
        """Store the encoder and decoder.

        Args:
            Encoder: Callable returning a 6-tuple whose last two items are the
                mean and the log-variance.
            Decoder: Callable returning a 5-tuple whose last item is the
                reconstructed image.
        """
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Return ``mean + var * epsilon`` with ``epsilon ~ N(0, I)``.

        ``var`` is used as a multiplicative scale; ``forward`` passes
        ``exp(0.5 * log_var)``, i.e. the standard deviation.
        """
        # randn_like already matches the shape, dtype and device of `var`, so
        # no explicit device transfer (and no global DEVICE) is needed.
        epsilon = torch.randn_like(var)        # sampling epsilon
        z = mean + var*epsilon                 # reparameterization trick
        return z

    def forward(self, x):
        """Run the full VAE pass and return ``(x_hat, mean, log_var)``."""
        *_, mean, log_var = self.Encoder(x)
        std = torch.exp(0.5 * log_var)         # log-variance -> standard deviation
        z = self.reparameterization(mean, std)
        *_, x_hat = self.Decoder(z)

        return x_hat, mean, log_var

In [8]:
# Build the two halves of the VAE from the hyperparameters defined earlier.
encoder = Encoder(x_dim, hidden_dim, latent_dim)
decoder = Decoder(latent_dim, hidden_dim, x_dim)



In [None]:
# Combine encoder and decoder into the VAE, then move it to the selected device.
model = Model(encoder, decoder)
model = model.to(DEVICE)

### Step 3. Define Loss function (reprod. loss) and optimizer

In [None]:
from torch.optim import Adam

# Module-form BCE criterion with the default mean reduction; loss_function
# below calls the functional form with reduction='sum' instead.
BCE_loss = nn.BCELoss(reduction='mean')

def loss_function(x, x_hat, mean, log_var):
    """Return the VAE loss: summed reconstruction BCE plus KL divergence.

    Args:
        x: Target tensor for the binary cross-entropy.
        x_hat: Predicted tensor, same shape as ``x``.
        mean: Mean of the approximate posterior.
        log_var: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor; both terms are summed over all elements (no averaging).
    """
    recon_term = F.binary_cross_entropy(x_hat, x, reduction='sum')

    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_inner = 1 + log_var - mean.pow(2) - log_var.exp()
    kl_term = -0.5 * kl_inner.sum()

    return recon_term + kl_term


# Adam over every parameter of the VAE, using the learning rate set earlier.
optimizer = Adam(params=model.parameters(), lr=lr)

### Step 4. Train Variational AutoEncoder (VAE)

In [None]:
print("Start training VAE...")
model.train()

for epoch in range(epochs):
    overall_loss = 0.0
    samples_seen = 0
    for x, _ in train_loader:
        # The batch is used in the shape the loader delivers it; forcing
        # x.view(batch_size, 3, 128, 128) would fail on a final batch that is
        # smaller than batch_size.
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)

        loss = loss_function(x, x_hat, mean, log_var)

        overall_loss += loss.item()
        samples_seen += x.size(0)

        loss.backward()
        optimizer.step()

    # loss_function sums over the batch, so divide by the number of samples
    # actually processed (the previous batch_idx*batch_size denominator was one
    # batch short and zero when there was only a single batch).
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / max(samples_seen, 1))

print("Finish!!")

### Step 5. Generate images from test dataset

In [None]:
import matplotlib.pyplot as plt

In [None]:
model.eval()

with torch.no_grad():
    for batch_idx, (x, _) in enumerate(tqdm(test_loader)):
        # x_dim is a tuple, so x.view(batch_size, x_dim) raises a TypeError;
        # the batch is passed on in the shape the loader delivers it.
        # NOTE(review): assumes test_loader yields float image batches of the
        # same size the model was trained on -- confirm the test dataset's
        # transform.
        x = x.to(DEVICE)

        x_hat, _, _ = model(x)

        # Only the first batch is needed: x and x_hat are displayed afterwards.
        break

In [None]:
def show_image(x, idx):
    """Display image ``idx`` of the batch ``x`` in a new matplotlib figure.

    Args:
        x: Batch tensor indexed along its first dimension. Each item is
            expected to be either (H, W) or channel-first (C, H, W).
        idx: Index of the image to show; must be smaller than ``x.shape[0]``.
    """
    img = x[idx].detach().cpu()

    if img.dim() == 3:
        # matplotlib expects channel-last (H, W, C) images.
        img = img.permute(1, 2, 0)
        if img.shape[-1] == 1:
            img = img.squeeze(-1)

    if img.is_floating_point():
        # imshow interprets float RGB data on a [0, 1] scale.
        img = img.clamp(0, 1)

    plt.figure()
    plt.imshow(img.numpy())

In [None]:
# Show the first input image of the test batch evaluated above.
show_image(x, idx=0)

In [None]:
# Show the model's reconstruction of that same test image.
show_image(x_hat, idx=0)

### Step 6. Generate image from noise vector

**Please note that this is not the correct generative process.**

* Even though we don't know the exact p(z|x), we can generate images from noise: the VAE training loss regularizes q(z|x) (the simple, tractable approximate posterior) to be close to N(0, I). If q(z|x) is close "enough" to N(0, I) (but not too tightly, because of the posterior collapse problem), samples from N(0, I) can stand in for the output of the VAE's encoder.

* To show this, I simply test with a noise vector sampled from N(0, I), similar to how a Generative Adversarial Network generates samples.

In [None]:
# The display cells further down index the generated batch up to idx=50, so
# draw at least 64 latent vectors even when batch_size is smaller than that.
num_samples = max(batch_size, 64)

with torch.no_grad():
    noise = torch.randn(num_samples, latent_dim).to(DEVICE)
    # The decoder returns (h_, c3, c2, c1, img); only the final image batch is
    # kept.
    *_, generated_images = decoder(noise)

In [None]:
# The decoder returns a tuple whose last element is the image batch; accept
# either that tuple or the image tensor itself. The images are saved in their
# own (N, C, H, W) shape rather than being forced into MNIST's (N, 1, 28, 28).
images_to_save = generated_images[-1] if isinstance(generated_images, (tuple, list)) else generated_images
save_image(images_to_save, 'generated_sample.png')

In [None]:
# NOTE(review): idx must be smaller than the number of generated samples --
# confirm that at least 13 were drawn.
show_image(generated_images, idx=12)

In [None]:
# Show the first generated sample.
show_image(generated_images, idx=0)

In [None]:
# NOTE(review): idx must be smaller than the number of generated samples --
# confirm that at least 2 were drawn.
show_image(generated_images, idx=1)

In [None]:
# NOTE(review): idx must be smaller than the number of generated samples --
# confirm that at least 11 were drawn.
show_image(generated_images, idx=10)

In [None]:
# NOTE(review): idx must be smaller than the number of generated samples --
# confirm that at least 21 were drawn.
show_image(generated_images, idx=20)

In [None]:
# NOTE(review): idx must be smaller than the number of generated samples --
# confirm that at least 51 were drawn.
show_image(generated_images, idx=50)