# VAE (Variational Autoencoder)

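## I. Understanding the Model

#### Goal: Generative Model
* Infer the latent variable z (latent vector) of the data X, and
* estimate the distribution of X from random samples of z, so that new data similar to X can be generated from that distribution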

<img src="../../shared/VAE_intuition.png" alt="Drawing" style="width: 500px;" align="left"/>

#### Structure: Encoder + Decoder

* **Encoder**: can be interpreted as the recognition model
* **Decoder**: can be interpreted as the generative model; this is the part that is the "actual goal" of a VAE

<img src="../../shared/VAE_step.png" alt="Drawing" style="width: 500px;"/>

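#### Process:

(1) **Variational Inference** - used as an alternative to directly finding the posterior $p(z|x)$, i.e. the distribution of latent variables z that, when decoded, generate data "similar to the real data"
* Computing $ p(z|x) = {{p(x|z)p(z)}\over{p(x)}}$ requires the distribution $ p(x) $ of "all" real data X, which cannot be obtained from the training data alone ($ \; p(x) = \int{p(x|z)p(z)dz} \;$ is intractable)
* Therefore a "tractable" distribution $ q(z|x) $ is introduced and trained so that $ q(z|x) $ approximates $ p(z|x) $ (Variational Inference)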
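
To make the marginal $p(x) = \int p(x|z)p(z)dz$ concrete, here is a minimal sketch on a toy one-dimensional model where the integral happens to be known in closed form. The toy model ($p(z) = N(0,1)$, $p(x|z) = N(z,1)$, hence $p(x) = N(0,2)$) and the helper names `normal_pdf` and `marginal_mc` are assumptions chosen for illustration; they are not part of the VAE described here.

```python
import math
import random

def normal_pdf(x, mean, var):
    """Density of a 1-D Gaussian N(mean, var) evaluated at x."""
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def marginal_mc(x, n_samples=100_000, seed=0):
    """Monte Carlo estimate of p(x) = E_{z ~ p(z)}[p(x|z)] for the toy model."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)          # z ~ p(z) = N(0, 1)
        total += normal_pdf(x, z, 1.0)   # p(x|z) = N(z, 1)
    return total / n_samples

x = 1.0
estimate = marginal_mc(x)
exact = normal_pdf(x, 0.0, 2.0)          # closed form exists only for this toy model
print(f"MC estimate: {estimate:.4f}, exact: {exact:.4f}")
```

With a neural-network decoder and a high-dimensional z there is no closed form to compare against, and naive sampling from the prior needs far too many samples to be practical, which is the motivation for approximating $p(z|x)$ with $q(z|x)$.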

(2) **argmax$ELBO(\phi)$** - used as an alternative to directly finding the parameter $ \phi $ of $ q(z|x)$ that minimizes $KL(q(z|x)||p(z|x))$
* Expanding $\log p(x)$ gives $\log p(x) = ELBO(\phi) + KL(q(z|x)||p(z|x))$
* We want the parameter $\phi$ of $q(z|x)$ that minimizes the KL divergence, but $p(z|x)$ is unknown. Since $\log p(x)$ does not depend on $\phi$, minimizing the KL term is equivalent to maximizing the ELBO term, so the goal becomes finding the $\phi$ that maximizes the ELBO
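
The decomposition of $\log p(x)$ can be derived in a few lines. The following is a sketch of the standard derivation, using only Bayes' rule and the fact that $q(z|x)$ integrates to 1:

$$
\begin{aligned}
\log p(x) &= E_{q(z|x)}[\log p(x)] \\
&= E_{q(z|x)}\left[\log \frac{p(x,z)}{p(z|x)}\right] \\
&= E_{q(z|x)}\left[\log \frac{p(x,z)}{q(z|x)}\right] + E_{q(z|x)}\left[\log \frac{q(z|x)}{p(z|x)}\right] \\
&= ELBO(\phi) + KL(q(z|x)||p(z|x))
\end{aligned}
$$

Because $KL(q(z|x)||p(z|x)) \ge 0$, it follows that $\log p(x) \ge ELBO(\phi)$, which is where the name "evidence lower bound" comes from. Substituting $p(x,z) = p(x|z)p(z)$ into the first expectation gives $ELBO(\phi) = E_{q(z|x)}[\log p(x|z)] - KL(q(z|x)||p(z))$.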

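(3) **Reparameterization Trick** - makes backpropagation possible when training to maximize the ELBO term
* Sampling z from the distribution $q(z|x)$ inside the forward pass is not a differentiable operation, so gradients cannot be backpropagated through it
* Therefore, instead of sampling $ z \sim N(\mu, \sigma^2) $ directly (a stochastic, non-differentiable step), z is rewritten as $ z = \mu + \epsilon \cdot \sigma $ with $ \sigma = \exp(logVar/2) $, where $\epsilon$ is a random sample from $N(0,1)$
* The ELBO term being maximized is $E_{q(z|x)}[log(p(x|z))] - KL(q_{\phi}(z|x_i)||p(z))$; its two terms are covered in (4)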
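
The effect of the trick can be checked with autograd. The following is a minimal sketch, assuming PyTorch is installed; the tensor values and the 3-dimensional latent size are illustrative, and the standard deviation is computed as $\exp(logVar/2)$ in the same way as the `reparameterize` method of the example in Section II. For a fixed `eps`, z is a deterministic function of `mu` and `log_var`, so gradients reach both.

```python
import torch

torch.manual_seed(0)

# Illustrative encoder outputs for a 3-dimensional latent space
mu = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
log_var = torch.tensor([0.0, -0.5, 1.0], requires_grad=True)

std = torch.exp(0.5 * log_var)   # sigma = exp(logVar / 2)
eps = torch.randn_like(std)      # eps ~ N(0, I); the randomness carries no gradient
z = mu + eps * std               # reparameterized sample

z.sum().backward()

print(mu.grad)        # dz/dmu = 1 in every dimension
print(log_var.grad)   # dz/dlogVar = 0.5 * eps * std
```

The randomness is isolated in `eps`, which is treated as a constant input, so the sampling step itself never has to be differentiated.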

(4) **Maximum Likelihood Estimation** - the Decoder estimates (reconstructs) a distribution close to X from the latent variable z, and the Encoder is trained so that the z extracted from X stays close to the prior p(z) (regularization)

* **Reconstruction Error** (first term of the ELBO): train the decoder output $g_\theta(z)$ to match the distribution of the data $X$ as closely as possible, maximizing $ E_{q(z|x)}[log(p(x|z))]$
    * Option 1: **assume the output of $g_\theta(z)$ follows a Bernoulli distribution** to obtain $p_\theta(x_i|z^i)$; rearranging $log(p_\theta(x_i|z^i))$ gives the **cross-entropy between $p_{i,j}$ and $X_{i,j}$**
    * Option 2: **assume the output of $g_\theta(z)$ follows a Gaussian distribution** to obtain $\mu, \sigma$; rearranging $log(p_\theta(x_i|z^i))$ gives, for a fixed $\sigma$, the **MSE between $\mu_{i,j}$ and $X_{i,j}$** (up to a scale factor and an additive constant)
* **Regularization Error** (second term of the ELBO): train $q_{\phi}(z|x_i)$ to stay close to the "known prior" $z \sim N(0, I)$, minimizing $KL(q_{\phi}(z|x_i)||p(z))$
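
When $q_{\phi}(z|x_i)$ is a diagonal Gaussian $N(\mu, \sigma^2)$ and the prior is $N(0, I)$, the KL term has the closed form $-\frac{1}{2}\sum_j (1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2)$. The following is a minimal sketch, assuming PyTorch is installed, that compares this expression with the value computed by `torch.distributions`; the helper name `kl_to_standard_normal` and the example tensors are hypothetical.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over all elements."""
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Made-up encoder outputs: 2 samples, 2 latent dimensions
mu = torch.tensor([[0.3, -1.2], [0.0, 2.0]])
log_var = torch.tensor([[0.1, -0.4], [0.0, 0.7]])

closed_form = kl_to_standard_normal(mu, log_var)

# Reference value from torch.distributions (element-wise KL, then summed)
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(q, p).sum()

print(closed_form.item(), reference.item())
```

The closed form avoids any sampling for this term, so only the reconstruction term of the ELBO needs a Monte Carlo estimate.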

<img src="../../shared/VAE_loss.png" alt="Drawing" style="width: 500px;"/>

<img src="../../shared/VAE_loss-all.png" alt="Drawing"/>

<hr>

## II. Example with MNIST

#### (0) Define Hyper-parameters / Helper Function

In [1]:
import torch
import os

In [2]:
# Device configuration: run tensors on the GPU if available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory where the generated and reconstructed images are saved
sample_dir = './results'
os.makedirs(sample_dir, exist_ok=True)

# Hyper-parameters
image_size = 784
h_dim = 400
z_dim = 20

num_epochs = 20
batch_size = 128
learning_rate = 1e-3

#### (1) Load Data

In [3]:
import torchvision # To Download MNIST Datasets from Torch 
import torchvision.transforms as transforms # To Transform MNIST "Images" to "Tensor"

In [4]:
train_data = torchvision.datasets.MNIST(root='./datasets',
                                        train=True,
                                        transform=transforms.ToTensor(),
                                        download=True)

# No test data is needed: new images are generated by sampling z ~ N(0, I)

#### (2) Define Dataloader

In [5]:
train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                           batch_size=batch_size,
                                           shuffle=True)

# No test loader is needed either

In [6]:
# cf) check how data_loader works
image, label = next(iter(train_loader))
print(image.size(), ": [Batch, Channel, Height, Width] Respectively")

torch.Size([128, 1, 28, 28]) : [Batch, Channel, Height, Width] Respectively
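
The model defined in the next step uses fully connected layers, so each `[1, 28, 28]` image has to be flattened to 784 values before it is fed in. The following is a minimal sketch of that reshaping, assuming PyTorch is installed; the all-zero batch is a dummy stand-in for real MNIST images.

```python
import torch

batch = torch.zeros(128, 1, 28, 28)    # dummy stand-in: [Batch, Channel, Height, Width]
flat = batch.view(-1, 28 * 28)         # flatten each image into 784 values
restored = flat.view(-1, 1, 28, 28)    # inverse reshape, e.g. before saving images

print(flat.shape, restored.shape)
```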


#### (3) Define Model

In [7]:
import torch.nn as nn
import torch.nn.functional as F

In [8]:
class VAE(nn.Module):
    def __init__(self, image_size=image_size, h_dim=h_dim, z_dim=z_dim):
        super().__init__()
        self.fc1 = nn.Linear(image_size, h_dim) # 784 nodes (28x28 MNIST image) -> 400 nodes (h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim) # 400 nodes (h_dim) -> 20 nodes (mean of z)
        self.fc3 = nn.Linear(h_dim, z_dim) # 400 nodes (h_dim) -> 20 nodes (log-variance of z)
        self.fc4 = nn.Linear(z_dim, h_dim) # 20 nodes (reparameterized z = mean + eps*std) -> 400 nodes (h_dim)
        self.fc5 = nn.Linear(h_dim, image_size) # 400 nodes (h_dim) -> 784 nodes (reconstructed 28x28 image)
        
    # Encoder: Encode Image to Latent Vector z
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    # Reparameterize: compute z = mean + eps*std with eps ~ N(0, I) instead of sampling z directly
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    # Decoder: Decode Reparameterized Latent Vector z to Reconstructed Image
    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h)) # F.sigmoid is deprecated in favor of torch.sigmoid
    
    # Forward pass: returns (reconstructed image, mean, log-variance) together
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

model = VAE().to(device)

#### (4) Set Loss & Optimizer

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# The total loss is defined in the training loop, since it is the sum of the reconstruction loss and the regularization (KL) loss

#### (5) Train / Test

In [10]:
# Load 'save_image' Function
from torchvision.utils import save_image

In [11]:
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(train_loader): # '_' as we don't need label of the input Image
        # Feed Forward
        x = x.to(device).view(-1, image_size) # Flatten 2D Image into 1D Nodes
        x_reconst, mu, log_var = model(x)
        
        # Compute the Total Loss
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum') # Sum over all elements (size_average=False is deprecated); see the note below
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        
        # Get Loss, Compute Gradient, Update Parameters
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Print Loss for Tracking Training
        if (i+1) % 50 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(train_loader), reconst_loss.item(), kl_div.item()))
            
    # Save Model on Last epoch
    if epoch+1 == num_epochs:
        torch.save(model.state_dict(), './model.pth')
    
    # Save Generated Image and Reconstructed Image at every Epoch
    with torch.no_grad():
        # Save the sampled images
        z = torch.randn(batch_size, z_dim).to(device) # Randomly Sample z
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the reconstructed images
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))



Epoch[1/20], Step [50/469], Reconst Loss: 27043.2285, KL Div: 689.4911
Epoch[1/20], Step [100/469], Reconst Loss: 21034.1328, KL Div: 1415.5076
Epoch[1/20], Step [150/469], Reconst Loss: 19505.9648, KL Div: 1832.9268
Epoch[1/20], Step [200/469], Reconst Loss: 16867.9688, KL Div: 2097.4224
Epoch[1/20], Step [250/469], Reconst Loss: 16869.5898, KL Div: 2308.5818
Epoch[1/20], Step [300/469], Reconst Loss: 16147.5137, KL Div: 2325.8838
Epoch[1/20], Step [350/469], Reconst Loss: 15486.3750, KL Div: 2537.8840
Epoch[1/20], Step [400/469], Reconst Loss: 14367.6504, KL Div: 2451.7480
Epoch[1/20], Step [450/469], Reconst Loss: 13675.8486, KL Div: 2696.8506
Epoch[2/20], Step [50/469], Reconst Loss: 12412.6924, KL Div: 2756.2368
Epoch[2/20], Step [100/469], Reconst Loss: 12826.1221, KL Div: 2783.3357
Epoch[2/20], Step [150/469], Reconst Loss: 12700.5537, KL Div: 2839.3696
Epoch[2/20], Step [200/469], Reconst Loss: 12788.1787, KL Div: 2835.9673
Epoch[2/20], Step [250/469], Reconst Loss: 11997.3789,

Epoch[13/20], Step [300/469], Reconst Loss: 10113.7803, KL Div: 3165.8167
Epoch[13/20], Step [350/469], Reconst Loss: 10065.7686, KL Div: 3288.4292
Epoch[13/20], Step [400/469], Reconst Loss: 10125.1221, KL Div: 3114.6873
Epoch[13/20], Step [450/469], Reconst Loss: 10762.6406, KL Div: 3220.4175
Epoch[14/20], Step [50/469], Reconst Loss: 10384.0479, KL Div: 3305.6655
Epoch[14/20], Step [100/469], Reconst Loss: 10105.0029, KL Div: 3344.3110
Epoch[14/20], Step [150/469], Reconst Loss: 10255.1602, KL Div: 3266.5461
Epoch[14/20], Step [200/469], Reconst Loss: 10026.7686, KL Div: 3295.8701
Epoch[14/20], Step [250/469], Reconst Loss: 10476.5791, KL Div: 3181.1997
Epoch[14/20], Step [300/469], Reconst Loss: 10312.2764, KL Div: 3229.7793
Epoch[14/20], Step [350/469], Reconst Loss: 10324.7256, KL Div: 3220.8438
Epoch[14/20], Step [400/469], Reconst Loss: 10080.6201, KL Div: 3233.8140
Epoch[14/20], Step [450/469], Reconst Loss: 10330.5430, KL Div: 3276.0830
Epoch[15/20], Step [50/469], Reconst Lo

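#### <a href="https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#binary_cross_entropy" target="_blank">cf) PyTorch F.binary_cross_entropy</a>
size_average: default=True (the loss is averaged over the elements); if set to False, the loss is the sum of the per-element losses. In recent PyTorch versions `size_average` is deprecated in favor of `reduction` (`'mean'` or `'sum'`).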
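
The difference between the two reductions, and the link to the Bernoulli log-likelihood from Section I, can be checked on a tiny example. The following is a minimal sketch, assuming a PyTorch version in which the `reduction` argument is available; the tensors are made-up values rather than real decoder outputs.

```python
import torch
import torch.nn.functional as F

p = torch.tensor([[0.9, 0.2], [0.6, 0.5]])   # predicted Bernoulli means (decoder output)
x = torch.tensor([[1.0, 0.0], [1.0, 0.0]])   # target pixel values

bce_sum = F.binary_cross_entropy(p, x, reduction='sum')
bce_mean = F.binary_cross_entropy(p, x, reduction='mean')

# Negative Bernoulli log-likelihood written out by hand
manual = -(x * torch.log(p) + (1 - x) * torch.log(1 - p)).sum()

print(bce_sum.item(), bce_mean.item(), manual.item())
```

Summing rather than averaging keeps the reconstruction term on the same scale as the KL term in the training loop, which is also computed with `torch.sum`.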

<hr>

## III. Understanding Through Visualization

#### 1. MNIST data distribution represented by the 20-dimensional Multivariate Gaussian Distribution defined in the example

<img src="../../shared/VAE_MNIST_output.png" alt="Drawing" style="width: 800px;" align="left"/>

#### 2. Process of estimating the MNIST distribution [[link](https://github.com/gamchanr/TA-EE4178/blob/master/shared/VAE_MNIST_simulation.gif)]

<img src="../../shared/VAE_MNIST_simulation.gif" alt="Drawing" style="width: 500px;"/>

<hr>

## References

* [오토인코더의 모든 것 (All About Autoencoders)](https://www.youtube.com/watch?v=o_peo6U7IRM)
* [Jeremy Jordan VAE](https://www.jeremyjordan.me/variational-autoencoders/)
* [Joseph Rocca VAE](https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73)
* [Taeu Kim VAE paper review](https://taeu.github.io/paper/deeplearning-paper-vae/)
* [Ratsgo VAE](https://ratsgo.github.io/generative%20model/2018/01/27/VAE/)
* [Multivariate Gaussian Distribution](https://www.sallys.space/blog/2018/03/20/multivariate-gaussian/)
* [BCE](https://curt-park.github.io/2018-09-19/loss-cross-entropy/)
* [PyTorch Official VAE Tutorial](https://github.com/pytorch/examples/blob/master/vae/main.py)