YouTube - Aladdin Persson
Variational Autoencoder from scratch in PyTorch

[Watch the video on YouTube](https://www.youtube.com/watch?v=VELQT1-hILo&list=LL&index=1)

In [1]:
import torch
from torch import nn

In [2]:
# Input img -> Hidden dim -> mean, std -> Reparameterization trick -> Decoder -> Output img
class VariationalAutoEncoder(nn.Module):
  """Fully connected variational autoencoder with one hidden layer per side.

  The encoder maps an input vector to a hidden representation and then to two
  latent vectors, ``mu`` and ``sigma``. A latent sample is drawn as
  ``mu + sigma * noise`` with standard-normal ``noise`` and passed through the
  decoder, which ends in a sigmoid so every reconstructed value lies in
  ``[0, 1]``.

  Args:
    input_dim: Size of the last dimension of the input (and of the output).
    h_dim: Width of the hidden layer used by both encoder and decoder.
    z_dim: Size of the latent vector.
  """

  def __init__(self, input_dim, h_dim=200, z_dim=20):
    super().__init__()

    # Encoder: input -> hidden, followed by two parallel heads.
    self.img_2hid = nn.Linear(in_features=input_dim, out_features=h_dim)
    self.hid_2mu = nn.Linear(in_features=h_dim, out_features=z_dim)
    self.hid_2sigma = nn.Linear(in_features=h_dim, out_features=z_dim)

    # Decoder: latent -> hidden -> input-sized output.
    self.z_2hid = nn.Linear(in_features=z_dim, out_features=h_dim)
    self.hid_2img = nn.Linear(in_features=h_dim, out_features=input_dim)

    # One shared activation module, used by both halves of the network.
    self.relu = nn.ReLU()

  def encode(self, x):
    """Return ``(mu, sigma)`` for input ``x``.

    Both values are raw linear-layer outputs computed from the same hidden
    activation; no constraint (e.g. positivity) is applied to ``sigma``.
    """
    hidden = self.img_2hid(x)
    hidden = self.relu(hidden)
    return self.hid_2mu(hidden), self.hid_2sigma(hidden)

  def decode(self, x):
    """Map a latent tensor ``x`` to a reconstruction with values in [0, 1]."""
    hidden = self.relu(self.z_2hid(x))
    logits = self.hid_2img(hidden)
    return torch.sigmoid(logits)

  def forward(self, x):
    """Encode ``x``, sample a latent vector and decode it.

    Returns:
      A tuple ``(x_reconstructed, mu, sigma)``.
    """
    mu, sigma = self.encode(x)
    # Reparameterization trick: the randomness lives in `noise`, so the sample
    # is a differentiable function of mu and sigma.
    noise = torch.randn_like(sigma)
    latent = mu + sigma * noise
    return self.decode(latent), mu, sigma