## Conditional Variational Autoencoder (**CVAE**) - PyTorch

In [None]:
import torch
import torch.utils.data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Set Up GPU & Hyperparameters

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global RNG so weight initialisation, batch shuffling and
# latent sampling repeat across "Restart & Run All" (some GPU kernels may
# still be nondeterministic).
seed = 42
torch.manual_seed(seed)

batch_size = 64   # samples per mini-batch
latent_size = 20  # dimensionality of the latent code z
epochs = 10       # number of full passes over the training set

# Download Data and Set Up DataLoaders

In [None]:
# Training split: downloaded into ../data when missing, reshuffled each epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    ),
    batch_size=batch_size,
    shuffle=True,
)

# Held-out split: read from the same directory and iterated in a fixed order.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=False,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=False,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# Add Condition

In [None]:
# Encode the desired class label as the conditioning vector (here, a digit 0-9).
def one_hot(labels, class_size):
  """Return a float one-hot encoding of integer class labels.

  Args:
    labels: 1-D integer tensor of class indices, each in [0, class_size).
    class_size: number of classes, i.e. the width of each encoded row.

  Returns:
    Float32 tensor of shape (len(labels), class_size) on the global
    `device`, with a single 1 per row at the position of that row's label.
  """
  # Vectorised: one library call instead of a Python loop that indexes the
  # target tensor once per sample.
  encoded = F.one_hot(labels.long(), num_classes=class_size)
  return encoded.to(device=device, dtype=torch.float32)

# **Build Model**

In [None]:
# Layer sizes for the MNIST configuration used in this notebook
# (feature_size = 784 = 28*28, class_size = 10):
# encoder:      (784 + 10) -> 512 -> 256
# latent heads: 256 -> latent_size (one head for the mean, one for log-variance)
# decoder:      (latent_size + 10) -> 256 -> 512 -> 784

class CVAE(nn.Module):
  """Conditional variational autoencoder built from fully connected layers.

  Both the encoder and the decoder receive the condition vector `c`
  concatenated to their regular input along the feature dimension.

  Args:
    feature_size: length of one flattened input sample.
    latent_size: dimensionality of the latent code z.
    class_size: length of the condition vector (number of classes).
  """

  def __init__(self, feature_size,  latent_size, class_size):
    super().__init__()
    self.feature_size = feature_size
    self.latent_size = latent_size
    self.class_size = class_size

    # Encoder trunk shared by the two latent heads.
    self.encoder = nn.Sequential(
        nn.Linear(feature_size + class_size, 512),
        nn.ReLU(),
        nn.Linear(512, 256),
        nn.ReLU()
    )
    self.fc_mu = nn.Linear(256, latent_size)  # mean of q(z|x, c)
    # Named `fc_var`, but forward() feeds its output to reparameterize() as a
    # log-variance. The attribute name is kept so existing state dicts load.
    self.fc_var = nn.Linear(256, latent_size)

    # Decoder; the final Sigmoid keeps every output value in [0, 1].
    self.decoder = nn.Sequential(
        nn.Linear(latent_size + class_size, 256),
        nn.ReLU(),
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, feature_size),
        nn.Sigmoid()
    )

  def encode(self, x, c):
    """Q(z|x, c): map flattened inputs and conditions to (mu, logvar)."""
    inputs = torch.cat([x, c], 1)  # join along the feature dimension (dim 1)
    h = self.encoder(inputs)
    mu = self.fc_mu(h)
    logvar = self.fc_var(h)
    return mu, logvar

  def reparameterize(self, mu, logvar):
    """Draw z = mu + eps * std with eps ~ N(0, I) and std = exp(logvar / 2)."""
    std = torch.exp(0.5*logvar)
    eps = torch.randn_like(std)
    return mu + eps*std

  def decode(self, z, c):
    """P(x|z, c): reconstruct flattened samples from latent codes and conditions."""
    inputs = torch.cat([z, c], 1)
    return self.decoder(inputs)

  def forward(self, x, c):
    """Encode, sample and decode a batch.

    Args:
      x: input batch; reshaped to (-1, feature_size) before encoding.
      c: condition batch with one row of length class_size per sample.

    Returns:
      Tuple (reconstruction, mu, logvar).
    """
    # Flatten with the configured feature size rather than a hard-coded
    # 28*28 so the model also works for inputs that are not MNIST-sized.
    mu, logvar = self.encode(x.view(-1, self.feature_size), c)
    z = self.reparameterize(mu, logvar)
    out = self.decode(z, c)
    return out, mu, logvar

# Loss Function & Optimizer

In [None]:
def loss_function(recon_x, x, mu, logvar):
  """Return the VAE loss summed over the batch: reconstruction BCE plus KL term.

  Args:
    recon_x: reconstruction with values in [0, 1].
    x: target batch holding the same number of elements as `recon_x`; it is
      reshaped to `recon_x`'s shape before comparison.
    mu: latent means.
    logvar: latent log-variances.

  Returns:
    Scalar tensor BCE + KLD (summed over the batch, not averaged).
  """
  # Match the target to the reconstruction's shape instead of assuming 784
  # features, so the loss also works for inputs that are not MNIST-sized.
  target = x.reshape(recon_x.shape)
  BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
  # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
  KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  return BCE + KLD

In [None]:
# MNIST setup: 28x28 images flattened to 784 features, 10 digit classes.
model = CVAE(feature_size=28*28, latent_size=latent_size, class_size=10)
model = model.to(device)

# Adam over every parameter of the CVAE.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

## Train

In [None]:
def train(epoch):
  """Train the global `model` for one epoch over `train_loader`.

  Prints the per-sample loss of every 100th batch and, once the epoch has
  finished, the average per-sample loss over the whole training set.

  Args:
    epoch: epoch number, used only in the progress messages.

  Returns:
    Average per-sample training loss for this epoch (float).
  """
  model.train()
  train_loss = 0.0
  for batch_idx, (data, labels) in enumerate(train_loader):
    data, labels = data.to(device), labels.to(device)
    condition = one_hot(labels, model.class_size)  # one-hot condition vector

    optimizer.zero_grad()
    recon_batch, mu, logvar = model(data, condition)
    loss = loss_function(recon_batch, data, mu, logvar)
    loss.backward()
    optimizer.step()

    # .item() yields a Python float, so the running sum is kept in double
    # precision and holds no reference to the autograd graph.
    train_loss += loss.item()

    if batch_idx % 100 == 0:
      print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                loss.item() / len(data)))

  # Report the epoch average once, after all batches. Printing it inside the
  # progress branch would divide a partial sum by the full dataset size.
  avg_loss = train_loss / len(train_loader.dataset)
  print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
  return avg_loss

## Test

In [None]:
# Colab-only: mount Google Drive into the runtime's filesystem.
from google.colab import drive

drive.mount('/content/gdrive/')

# Folder inside the mount point; not referenced by the cells visible here.
path = "/content/gdrive/My Drive"

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


In [None]:
def test(epoch):
  """Evaluate the global `model` on `test_loader` without tracking gradients.

  For the first batch, up to five inputs and their reconstructions are
  written to 'reconstruction_<epoch>.png' (inputs on the top row,
  reconstructions below).

  Args:
    epoch: epoch number, used in the image file name.

  Returns:
    Average per-sample test loss (float).
  """
  model.eval()
  test_loss = 0.0
  with torch.no_grad():
    for i, (data, labels) in enumerate(test_loader):
      data, labels = data.to(device), labels.to(device)
      condition = one_hot(labels, model.class_size)
      recon_batch, mu, logvar = model(data, condition)
      # .item() accumulates in double precision as a plain Python float.
      test_loss += loss_function(recon_batch, data, mu, logvar).item()
      if i == 0:
        n = min(data.size(0), 5)
        # Give the flat reconstruction the image shape of the input batch
        # rather than assuming 1x28x28.
        comparison = torch.cat([data[:n], recon_batch.view_as(data)[:n]])
        save_image(comparison.cpu(),
                   'reconstruction_' + str(epoch) + '.png', nrow=n)

  test_loss /= len(test_loader.dataset)
  print('====> Test set loss: {:.4f}'.format(test_loss))
  return test_loss

In [None]:
# Train and evaluate for `epochs` epochs; after each one, decode one random
# latent vector per digit class and save the ten generated images.
for epoch in range(1, epochs + 1):
  train(epoch)
  test(epoch)
  with torch.no_grad():
    # Row i of the identity matrix is the one-hot condition for digit i.
    # Placed on `device` (not .cuda()) so the loop also runs on CPU-only hosts.
    c = torch.eye(10, 10).to(device)
    # The latent width follows the configured `latent_size`, not a literal 20.
    sample = torch.randn(10, latent_size).to(device)
    sample = model.decode(sample, c).cpu()
    save_image(sample.view(10, 1, 28, 28),
               'sample_' + str(epoch) + '.png')

Train Epoch: 1 [0/60000]	Loss: 89.070122
====> Epoch: 1 Average loss: 0.0950
Train Epoch: 1 [6400/60000]	Loss: 89.744400
====> Epoch: 1 Average loss: 9.8224
Train Epoch: 1 [12800/60000]	Loss: 91.170189
====> Epoch: 1 Average loss: 19.6386
Train Epoch: 1 [19200/60000]	Loss: 87.449814
====> Epoch: 1 Average loss: 29.4030
Train Epoch: 1 [25600/60000]	Loss: 89.897621
====> Epoch: 1 Average loss: 39.2184
Train Epoch: 1 [32000/60000]	Loss: 91.265015
====> Epoch: 1 Average loss: 49.0401
Train Epoch: 1 [38400/60000]	Loss: 92.518311
====> Epoch: 1 Average loss: 58.8655
Train Epoch: 1 [44800/60000]	Loss: 92.403885
====> Epoch: 1 Average loss: 68.6851
Train Epoch: 1 [51200/60000]	Loss: 90.678917
====> Epoch: 1 Average loss: 78.4290
Train Epoch: 1 [57600/60000]	Loss: 95.128967
====> Epoch: 1 Average loss: 88.2764
====> Test set loss: 94.2502
Train Epoch: 2 [0/60000]	Loss: 92.964066
====> Epoch: 2 Average loss: 0.0992
Train Epoch: 2 [6400/60000]	Loss: 94.003601
====> Epoch: 2 Average loss: 9.9238
T