<a href="https://colab.research.google.com/github/haruka-inb/pytorch_practice/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational Autoencoder

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [None]:
# Device configuration: use the GPU when one is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
image_size = 784       # length of one flattened image (images are viewed as 28x28 elsewhere)
h_dim = 400            # hidden layer width
z_dim = 20             # latent space dimension
num_epochs = 15
batch_size = 128
learning_rate = 1e-3
# NOTE(review): `dir` shadows the built-in dir(); the name is kept because later cells read it.
dir = 'sample_dir'

# Create the output directory if it does not exist.
# makedirs(exist_ok=True) avoids the check-then-create race and also handles nested paths.
os.makedirs(dir, exist_ok=True)

# Download the MNIST training set (images are converted to tensors by ToTensor)
# NOTE(review): the leading '/' makes `root` an absolute path — confirm that is intended
# rather than the relative '../../data'.
mnist = torchvision.datasets.MNIST(root='/../../data', train=True,
                                   transform=transforms.ToTensor(), download=True)

# Create a data loader that yields shuffled mini-batches of `batch_size` samples
data_loader = torch.utils.data.DataLoader(mnist, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /../../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 159874606.79it/s]


Extracting /../../data/MNIST/raw/train-images-idx3-ubyte.gz to /../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /../../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 28112251.99it/s]


Extracting /../../data/MNIST/raw/train-labels-idx1-ubyte.gz to /../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /../../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 51745513.70it/s]

Extracting /../../data/MNIST/raw/t10k-images-idx3-ubyte.gz to /../../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4175915.99it/s]


Extracting /../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /../../data/MNIST/raw



In [None]:
# VAE model
class VAE(nn.Module):
  """Fully-connected variational autoencoder.

  The encoder maps a flattened input to the mean and log-variance of a
  diagonal Gaussian over the latent code; the decoder maps a latent code
  back to values in (0, 1) with the same size as the input.

  Args:
    image_size: Length of the flattened input vector.
    h_dim: Width of the hidden layer used by both encoder and decoder.
    z_dim: Dimension of the latent code.
  """

  def __init__(self, image_size=784, h_dim=400, z_dim=20):
    super().__init__()
    self.fc1 = nn.Linear(image_size, h_dim)  # encoder: input -> hidden
    self.fc2 = nn.Linear(h_dim, z_dim)       # encoder head: hidden -> mu
    self.fc3 = nn.Linear(h_dim, z_dim)       # encoder head: hidden -> log_var
    self.fc4 = nn.Linear(z_dim, h_dim)       # decoder: latent -> hidden
    self.fc5 = nn.Linear(h_dim, image_size)  # decoder: hidden -> output

  def encoder(self, x):
    """Encode `x` and return the tuple `(mu, log_var)` of the latent Gaussian."""
    h = F.relu(self.fc1(x))
    return self.fc2(h), self.fc3(h)

  def reparameterize(self, mu, log_var):
    """Sample `z = mu + std * eps` with `eps ~ N(0, I)`.

    Drawing `eps` separately keeps the sample differentiable with respect to
    `mu` and `log_var`.
    """
    std = torch.exp(log_var/2)
    # The noise must be standard normal (randn_like); rand_like would draw
    # uniform samples in [0, 1), which does not match the N(0, I) prior.
    eps = torch.randn_like(std)
    return mu + std * eps

  def decode(self, z):
    """Decode latent codes `z` into outputs squashed to (0, 1) by a sigmoid."""
    h = F.relu(self.fc4(z))
    # torch.sigmoid replaces the deprecated F.sigmoid
    return torch.sigmoid(self.fc5(h))

  def forward(self, x):
    """Run encode -> sample -> decode and return `(x_reconst, mu, log_var)`."""
    mu, log_var = self.encoder(x)
    z = self.reparameterize(mu, log_var)
    x_reconst = self.decode(z)
    return x_reconst, mu, log_var

# Build the model from the configured hyper-parameters and move it to the selected device
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)

# Define the reconstruction criterion and the optimizer.
# reduction='sum' is the supported spelling of the deprecated size_average=False:
# the binary cross-entropy is summed over all elements of the batch.
criteria = nn.BCELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Display the model architecture as the cell output
model



VAE(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc2): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=400, out_features=20, bias=True)
  (fc4): Linear(in_features=20, out_features=400, bias=True)
  (fc5): Linear(in_features=400, out_features=784, bias=True)
)

In [None]:
# Train the VAE.
# NOTE: the loop variables `e`, `i` and `images` stay at cell scope on purpose;
# the image-generation cell reads them after training finishes.
for e in range(num_epochs):
  for i, (images, _) in enumerate(data_loader):

    # Move the batch to the target device and flatten each image to a vector
    images = images.to(device)
    images = images.view(-1, image_size)

    # Clear gradients accumulated by the previous step
    optimizer.zero_grad()

    # Forward pass
    recon_batch, mu, log_var = model(images)

    # Reconstruction loss measures how well the model reproduces the input images.
    reconst_loss = criteria(recon_batch, images)

    # KL divergence keeps the encoded distribution close to the fixed latent prior
    # (closed-form expression from the VAE paper).
    kl_inner = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_div = -0.5 * torch.sum(kl_inner)

    # Backpropagate the combined loss and update the weights
    total_loss = reconst_loss + kl_div
    total_loss.backward()
    optimizer.step()

    # Report epoch, step and both loss terms every 10th step
    if (i + 1) % 10 != 0:
      continue
    print(f"Epoch {e+1}/{num_epochs}, Step {i+1}/{len(data_loader)}, Reconstruction Loss: {reconst_loss.item()}, KL Divergent: {kl_div.item()}")


Epoch 1/15, Step 10/469, Reconstruction Loss: 35250.26953125, KL Divergent: 4252.22021484375
Epoch 1/15, Step 20/469, Reconstruction Loss: 29001.24609375, KL Divergent: 1044.8369140625
Epoch 1/15, Step 30/469, Reconstruction Loss: 26795.5390625, KL Divergent: 755.898681640625
Epoch 1/15, Step 40/469, Reconstruction Loss: 26114.56640625, KL Divergent: 307.8574523925781
Epoch 1/15, Step 50/469, Reconstruction Loss: 25770.65234375, KL Divergent: 282.0201721191406
Epoch 1/15, Step 60/469, Reconstruction Loss: 25286.4140625, KL Divergent: 386.7532958984375
Epoch 1/15, Step 70/469, Reconstruction Loss: 22861.359375, KL Divergent: 474.0303955078125
Epoch 1/15, Step 80/469, Reconstruction Loss: 21247.828125, KL Divergent: 643.3482666015625
Epoch 1/15, Step 90/469, Reconstruction Loss: 20646.951171875, KL Divergent: 716.7455444335938
Epoch 1/15, Step 100/469, Reconstruction Loss: 19290.17578125, KL Divergent: 669.1046142578125
Epoch 1/15, Step 110/469, Reconstruction Loss: 19173.328125, KL Dive

In [None]:
# Generate images using the trained model.
# NOTE: this cell relies on state left over from the training loop: `e` (index of the
# last epoch) and `images` (the last training batch, already flattened and on `device`).
with torch.no_grad():

  # Save images decoded from latent vectors drawn from a standard normal distribution
  z = torch.randn(batch_size, z_dim).to(device)
  outputs = model.decode(z).view(-1, 1, 28, 28)
  # Tag the file with the epoch number (it previously used the step index `i`),
  # so it pairs with the matching 'reconst-*' file below.
  save_image(outputs, os.path.join(dir, 'sampled-{}.png'.format(e+1)))

  # Save the reconstructed images side by side with their inputs
  # (concatenated along the width axis, dim=3)
  out, _, _ = model(images)
  x_concat = torch.cat([images.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
  save_image(x_concat, os.path.join(dir, 'reconst-{}.png'.format(e+1)))