# VAE
(송경우교수님 딥러닝강의 2022) https://www.youtube.com/watch?v=V-lWbJtNzTc&list=PLeiav_J6JcY8iFItzNZ_6PMlz9W4_jz5J&index=58

<img src="https://drive.google.com/uc?id=1d4eR5ta0-gOjods4S60vOyR_IUb-oLuD" height="300">

## Generative model

+ MLE의 관점에서 보았을 때, $p_\theta(x)$를 최대화 하는 파라미터를 찾는 것이다!

결국 목적함수는,   
$$\theta^* = \arg\max\limits_{\theta}\cfrac{1}{N}\sum_{i=1}^N\log{p_\theta(x_i)}$$  

이다.

#### Variational Inference
+ 어떤 조건이 주어졌을 때의 확률($p(z|x)$)을 다루기 쉬운 확률분포($q(z)$)로 근사하는 것.

<img src="https://drive.google.com/uc?id=1ryponrQU_kCjVlBcexogjeET47CrquS9" width=600>

<img src="https://drive.google.com/uc?id=1moaCYsW9Tyy0SDBFMTNZtXXTfQ3wjaZ_" width=800>

즉, log-likelihood($\log{p(x_i)}$)의 lower bound인 $L$를 maximize하면 결국에 $q_i(z)$가 $p(z|x_i)$와 가까워져, 실제 샘플에 대응하는 latent를 더 잘 뽑아줄 수 있게 된다.

이를 학습하기 위해 elbo를 maximize 하는 과정을 살펴보면,  
i번째 샘플 데이터를 학습할 때마다, $q_i$로부터 $\mu_i, \Sigma_i$를 얻어내고, 이 분포에서 뽑은 $z$로부터 다시 $\hat{x}$를 만들어내서($\theta$), $x_i$와 닮도록 학습한다.  

이 때, 각 데이터 샘플마다 뮤와 시그마에 대응시키는 파라미터가 필요하다.  
-> 너무 많다... 새로운 데이터 들어올때마다 또 추가해야한다.

#### Amortized Variational Inference
+ $x_i$에 대응되는 $z_i$가 존재하는 상황
    + N개의 데이터가 있다면,,, 원래는 N개의 파라미터를 대응시켜서 학습시켰다..(n개의 튜플만들어서 각 튜플마다 파라미터로서 바꾸는느낌)
    + 이걸 차라리 하나의 네트워크를 통해 mapping을 시켜주겠다!!(보통 DNN에서 입력, 출력 원하는 방향으로 하듯이)

수식으로 보면, $q_i(z) \approx q_\phi(z|x_i)$, 즉 왼쪽 것 대신 오른쪽 것으로 VI를 하겠다는 뜻이다.  
데이터 하나마다 학습되는 과정을 수식으로 바라보면,  
$x_i$-> NN($q_\phi(z|x)$) -> $\mu(x_i), \sigma(x_i)$ -> $z=\mu(x_i) + \epsilon\sigma(x_i)$ -> NN($p_\theta(x|z)$) -> $\hat{x} \approx x$  
<img src="https://drive.google.com/uc?id=1d4eR5ta0-gOjods4S60vOyR_IUb-oLuD" height="300">

<img src="https://drive.google.com/uc?id=1y2FU4IzeWUANlmsDTDfc6rEsdfnaGJT6" height=400>

## 코드 실습

In [None]:
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [None]:
# MNIST training split (downloaded into ./data if missing).
# Each image is resized to 32x32 and then converted to a tensor.
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
data_set = torchvision.datasets.MNIST(
    './data',
    train=True,
    transform=preprocess,
    download=True,
)

In [None]:
# Show the dataset summary as the cell output.
data_set

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=(32, 32), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
           )

In [None]:
# Mini-batches of 100 images for training. shuffle=True reorders the samples
# every epoch so the model does not see the same fixed batches each time
# (DataLoader defaults to shuffle=False).
data_loader = torch.utils.data.DataLoader(dataset=data_set, batch_size=100, shuffle=True)

### 모델구성
+ mnist data : 28x28 (아래 코드에서는 32x32로 resize해서 사용)
+ network (아래 코드의 기본값 기준) :
channel : 16, 32, 64, 128 /
WxH : 32-16-8-4-2 -> fc_mu 128x2x2(=512) -> latent_dim(mu), fc_var 512 -> latent_dim(log var)

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder is a stack of stride-2 3x3 convolutions (one per entry of
    ``hidden_dims``) followed by two linear heads producing the mean and the
    log-variance of the approximate posterior. The decoder mirrors it with
    stride-2 transposed convolutions and ends in a sigmoid.

    The linear heads expect ``hidden_dims[-1] * 4`` features, i.e. the encoder
    must reduce the input to a 2x2 feature map. The decoder starts from a 2x2
    map and doubles the resolution once per entry of ``hidden_dims``, so with
    the default four stages it outputs 32x32 images.
    """

    def __init__(self,
                 in_channels,
                 latent_dim,
                 hidden_dims=None,
                 **kwargs):
        """Build the encoder, the latent heads and the decoder.

        Args:
            in_channels: Number of channels of the input images; the
                reconstruction has the same number of channels.
            latent_dim: Dimensionality of the latent vector ``z``.
            hidden_dims: Channel widths of the encoder stages, in encoder
                order. Defaults to ``[16, 32, 64, 128]``. The sequence is
                copied and never modified.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__()

        self.latent_dim = latent_dim

        # Work on a private copy: reversing the caller's list in place would
        # silently change it for any later use.
        if hidden_dims is None:
            hidden_dims = [16, 32, 64, 128]
        else:
            hidden_dims = list(hidden_dims)

        # Remember what decode() needs to reshape the linear output, instead
        # of hard-coding the default width there.
        self._bottleneck_channels = hidden_dims[-1]
        flat_features = hidden_dims[-1] * 4  # bottleneck channels * 2 * 2
        image_channels = in_channels

        # Build Encoder: every stage halves the spatial resolution.
        encoder_blocks = []
        prev_channels = in_channels
        for h_dim in hidden_dims:
            encoder_blocks.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU()
                )
            )
            prev_channels = h_dim

        self.encoder = nn.Sequential(*encoder_blocks)
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build Decoder: mirror of the encoder, widest stage first.
        self.decoder_input = nn.Linear(latent_dim, flat_features)

        decoder_dims = hidden_dims[::-1]

        decoder_blocks = []
        for c_in, c_out in zip(decoder_dims, decoder_dims[1:]):
            decoder_blocks.append(
                nn.Sequential(
                    nn.ConvTranspose2d(c_in,
                                       c_out,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU()
                )
            )

        self.decoder = nn.Sequential(*decoder_blocks)

        # Last upsampling step, then project back to image channels in [0, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(decoder_dims[-1],
                               decoder_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(decoder_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(decoder_dims[-1], out_channels=image_channels,
                      kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def encode(self, input):
        """Encode a batch of images into posterior parameters.

        Args:
            input: Image batch fed to the convolutional encoder.

        Returns:
            ``[mu, log_var]``, each with ``latent_dim`` features per sample.
        """
        features = self.encoder(input)
        features = torch.flatten(features, start_dim=1)

        mu = self.fc_mu(features)
        log_var = self.fc_var(features)

        return [mu, log_var]

    def decode(self, z):
        """Map latent vectors ``z`` back to image space.

        Returns:
            The decoder output after the final sigmoid.
        """
        result = self.decoder_input(z)
        # Reshape to the 2x2 bottleneck feature map the decoder starts from.
        result = result.view(-1, self._bottleneck_channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``logvar`` (the reparameterization trick).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input, **kwargs):
        """Encode, sample a latent vector and decode.

        Returns:
            ``[reconstruction, input, mu, log_var]`` -- the same order that
            ``loss_function`` expects as positional arguments.
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      kld_weight=1.0):
        """Return the VAE loss: reconstruction MSE plus weighted KL term.

        Args:
            *args: ``(recons, input, mu, log_var)`` in this order.
            kld_weight: Multiplier applied to the KL term. The default of
                ``1.0`` reproduces the plain sum ``recons_loss + kld_loss``.

        Returns:
            A scalar tensor.
        """
        recons = args[0]
        target = args[1]
        mu = args[2]
        log_var = args[3]

        # NOTE: mse_loss averages over every element, whereas the KL term is
        # summed over the latent dimensions (then averaged over the batch), so
        # the two terms are on different scales; use kld_weight to rebalance.
        recons_loss = F.mse_loss(recons, target)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0)

        return recons_loss + kld_weight * kld_loss

    def sample(self,
               num_samples,
               current_device,
               **kwargs):
        """Generate images by decoding latent vectors drawn from N(0, I).

        Args:
            num_samples: Number of images to generate.
            current_device: Device on which the latent vectors are created;
                it should be the device the model lives on.

        Returns:
            The decoded batch of ``num_samples`` images.
        """
        z = torch.randn(num_samples,
                        self.latent_dim,
                        device=current_device)

        return self.decode(z)

    def generate(self, x, **kwargs):
        """Return the reconstruction of ``x`` (first element of ``forward``)."""
        return self.forward(x)[0]

In [None]:
# VAE for single-channel images with a 200-dimensional latent space,
# placed on the selected device.
model = VAE(in_channels=1, latent_dim=200)
model = model.to(device)

# Adam over all model parameters with beta1 lowered to 0.5.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)

In [None]:
# Train the VAE for 30 epochs, logging the loss every 100 iterations and
# saving one generated sample every 200 iterations.

# Make sure the output directory exists before the first save_image call.
sample_dir = "./vae_samples"
os.makedirs(sample_dir, exist_ok=True)

num = 0  # global iteration counter, used to name the saved samples
for epoch in range(30):
    for i, (images, _) in enumerate(data_loader):
        # forward: encode, sample z with the reparameterization trick, decode
        x = images.to(device)
        mu, log_var = model.encode(x)
        z = model.reparameterize(mu, log_var)

        x_rec = model.decode(z)

        # compute loss
        loss = model.loss_function(x_rec, x, mu, log_var)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"epoch : {epoch+1} | iter : {i} | loss : {loss.item():.4}")

        if i % 200 == 0:
            # Sample in eval mode and without autograd: otherwise the
            # batch-of-1 prior sample would update the BatchNorm running
            # statistics and build an unused computation graph.
            model.eval()
            with torch.no_grad():
                samples = model.sample(1, device)
            model.train()
            save_image(samples, f"{sample_dir}/sample{num}.png")
        num += 1

epoch : 1 | iter : 0 | loss : 25.12
epoch : 1 | iter : 100 | loss : 0.2628
epoch : 1 | iter : 200 | loss : 0.2504
epoch : 1 | iter : 300 | loss : 0.1924
epoch : 1 | iter : 400 | loss : 0.1042
epoch : 1 | iter : 500 | loss : 0.2347
epoch : 2 | iter : 0 | loss : 0.08283
epoch : 2 | iter : 100 | loss : 0.06895
epoch : 2 | iter : 200 | loss : 0.07953
epoch : 2 | iter : 300 | loss : 0.07523
epoch : 2 | iter : 400 | loss : 0.07775
epoch : 2 | iter : 500 | loss : 0.08221
epoch : 3 | iter : 0 | loss : 0.0736
epoch : 3 | iter : 100 | loss : 0.06282
epoch : 3 | iter : 200 | loss : 0.07073
epoch : 3 | iter : 300 | loss : 0.05984
epoch : 3 | iter : 400 | loss : 0.06127
epoch : 3 | iter : 500 | loss : 0.06909
epoch : 4 | iter : 0 | loss : 0.06461
epoch : 4 | iter : 100 | loss : 0.06142
epoch : 4 | iter : 200 | loss : 0.07161
epoch : 4 | iter : 300 | loss : 0.05889
epoch : 4 | iter : 400 | loss : 0.05847
epoch : 4 | iter : 500 | loss : 0.06437
epoch : 5 | iter : 0 | loss : 0.06036
epoch : 5 | iter :

In [None]:
# class VAE(nn.Module):

#     def __init__(self,
#                  input_feature,
#                  latent_dim,
#                  hidden_dim,
#                  **kwargs):
#         super(VAE, self).__init__()


#         # Build Encoder
#         self.e_fc1 = nn.Sequential(
#             nn.Linear(input_feature, hidden_dim),
#             nn.BatchNorm2d(hidden_dim),
#             nn.LeakyReLU(0.2)
#             )

#         self.e_fc2 = nn.Sequential(
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.BatchNorm2d(hidden_dim),
#             nn.LeakyReLU(0.2)
#             )

#         # output layer
#         self.mu = nn.Linear(hidden_dim, latent_dim)
#         self.var = nn.Linear(hidden_dim, latent_dim)



#         # Build Decoder

#         self.d_fc1 = nn.Sequential(
#             nn.Linear(latent_dim, hidden_dim),
#             nn.BatchNorm2d(hidden_dim),
#             nn.LeakyReLU(0.2)
#             )

#         self.d_fc2 = nn.Sequential(
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.BatchNorm2d(hidden_dim),
#             nn.LeakyReLU(0.2)
#             )

#         self.final_layer = nn.Sequential(
#             nn.Linear(hidden_dim, input_feature),
#             nn.Tanh()
#         )


#     def encode(self, input):
#         result = self.e_fc1(input)
#         result = self.e_fc2(result)

#         mu = self.mu(result)
#         log_var = self.var(result)

#         return [mu, log_var]

#     def decode(self, z):
#         result = self.d_fc1(z)
#         result = self.d_fc2(result)
#         result = self.final_layer(result)
#         return result

#     def reparameterize(self, mu, logvar):
#         std = torch.exp(0.5*logvar)
#         eps = torch.randn_like(std)
#         return eps*std + mu

#     def loss_function(self,
#                       *args):
#         recons = args[0]
#         input = args[1]
#         mu = args[2]
#         log_var = args[3]

#         recons_loss = F.mse_loss(recons, input)
#         kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0)

#         loss = recons_loss + kld_loss

#         return loss


#     def sample(self,
#                num_samples,
#                current_device,
#                **kwargs):

#         z = torch.randn(num_samples,
#                         self.latent_dim)

#         z = z.to(current_device)

#         samples = self.decode(z)
#         samples = samples.view(-1, 28, 28)
#         return samples

#     def generate(self, x, **kwargs):
#         return self.forward(x)[0]