### Conditional Variational Autoencoder(CVAE)
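The Conditional Variational Autoencoder (CVAE) [1] is an extension of the Variational Autoencoder (VAE) [2] in which both the encoder and the decoder are conditioned on an additional variable y (for example, a class label).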

The objective function of the vanilla VAE is
$$
\log P ( X ) - D _ { K L } [ Q ( z | X ) \| P ( z | X ) ] = E [ \log P ( X | z ) ] - D _ { K L } [ Q ( z | X ) \| P ( z ) ]
$$
In the conditional VAE, the encoder is $Q(z|X, y)$ and the decoder is $P(X|z, y)$. The objective function above therefore becomes
$$
\log P ( X | y ) - D _ { K L } [ Q ( z | X ,y ) \| P ( z | X ,y ) ] = E [ \log P ( X | z ,y ) ] - D _ { K L } [ Q ( z | X ,y ) \| P ( z | y ) ]
$$


- - -
[1]: Sohn, Kihyuk, Honglak Lee, and Xinchen Yan. “Learning Structured Output Representation using Deep Conditional Generative Models.” Advances in Neural Information Processing Systems. 2015.

[2]: Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).

In [1]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import numpy as np
from torch.autograd import Variable

# Set to True to force CPU execution even when a GPU is present.
no_cuda = False
cuda_available = not no_cuda and torch.cuda.is_available()

# Training hyper-parameters and the RNG seed.
BATCH_SIZE = 16
EPOCH = 100
SEED = 8

torch.manual_seed(SEED)

device = torch.device("cuda" if cuda_available else "cpu")

# Extra DataLoader options that are only passed when running on a GPU.
if cuda_available:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# Both splits use the same image-to-tensor conversion.
to_tensor = transforms.ToTensor()

# Training split (downloaded into ./MNIST_data when missing).
train_set = datasets.MNIST('./MNIST_data', train=True, download=True,
                           transform=to_tensor)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, **kwargs)

# Test split, read from the same directory.
test_set = datasets.MNIST('./MNIST_data', train=False, transform=to_tensor)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=True, **kwargs)

# Show the test labels as a quick sanity check of the loaded data.
print(test_loader.dataset.targets)

tensor([7, 2, 1,  ..., 4, 5, 6])


In [2]:
class CVAE(nn.Module):
    """Conditional variational autoencoder built from fully connected layers.

    The encoder maps the concatenation of a flattened input and its condition
    vector to the mean and log-variance of a diagonal Gaussian over the latent
    code. The decoder maps the concatenation of a latent code and the
    condition vector to sigmoid outputs, one per input element.

    Args:
        input_dim: Number of elements in one flattened input (default 784).
        label_dim: Width of the condition vector (default 10).
        hidden_dim: Width of the two hidden layers (default 400).
        latent_dim: Size of the latent code (default 20).
    """

    def __init__(self, input_dim=784, label_dim=10, hidden_dim=400, latent_dim=20):
        super().__init__()

        self.input_dim = input_dim
        self.label_dim = label_dim
        self.latent_dim = latent_dim

        # Encoder: (input ++ condition) -> hidden -> (mu, logvar).
        self.fc1 = nn.Linear(input_dim + label_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)
        # Decoder: (latent ++ condition) -> hidden -> reconstruction.
        self.fc3 = nn.Linear(latent_dim + label_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x, y):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x`` given ``y``.

        Args:
            x: Flattened inputs of shape ``(batch, input_dim)``.
            y: Condition vectors of shape ``(batch, label_dim)``, any dtype.
        """
        # Follow x's dtype/device rather than a module-level CUDA flag, so the
        # concatenation below can never mix devices or dtypes.
        y = y.to(device=x.device, dtype=x.dtype)

        h1 = F.relu(self.fc1(torch.cat((x, y), 1)))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def decode(self, z, y):
        """Return the reconstruction for latent codes ``z`` given ``y``.

        Args:
            z: Latent codes of shape ``(batch, latent_dim)``.
            y: Condition vectors of shape ``(batch, label_dim)``, any dtype.

        Returns:
            Tensor of shape ``(batch, input_dim)`` with values in ``[0, 1]``.
        """
        y = y.to(device=z.device, dtype=z.dtype)

        h3 = F.relu(self.fc3(torch.cat((z, y), 1)))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x, y):
        """Encode ``x``, sample a latent code and decode it again.

        ``x`` may have any shape whose trailing elements flatten to
        ``input_dim`` per example.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim), y)

        z = self.reparameterize(mu, logvar)
        return self.decode(z, y), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    """Return the reconstruction loss plus the KL divergence as one scalar.

    Both terms are summed over all elements and over the batch.

    Args:
        recon_x: Reconstructed values in ``[0, 1]``, same shape as ``x``.
        x: Target values in ``[0, 1]``.
        mu: Means of the approximate posterior.
        logvar: Log-variances of the approximate posterior.
    """
    # Binary cross-entropy between reconstruction and target.
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(per_element)

    return reconstruction_term + kl_term

# Build the network and move its parameters to the selected device.
model = CVAE()
model.to(device)

# Adam optimiser over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [3]:
def idx2onehot(idx, n):
    """Convert integer class indices into one-hot rows.

    Args:
        idx: Integer tensor of class indices with at most two dimensions
            (for example shape ``(N,)`` or ``(N, 1)``). It is flattened, so
            every element produces one output row.
        n: Number of classes, i.e. the width of each one-hot row.

    Returns:
        Float tensor of shape ``(idx.numel(), n)`` on the same device as
        ``idx``, holding a single 1 per row in the column named by the index.

    Raises:
        ValueError: If ``idx`` has more than two dimensions or contains an
            index outside ``[0, n)``.
    """
    if idx.dim() > 2:
        raise ValueError(
            'idx must have at most 2 dimensions, got {}'.format(idx.dim()))

    # scatter_ needs an int64 column vector of target positions.
    column = idx.reshape(-1, 1).long()
    # Allocate next to idx so scatter_ never mixes devices.
    onehot = torch.zeros(column.size(0), n, device=idx.device)

    # min()/max() are undefined on empty tensors; nothing to scatter anyway.
    if column.numel() == 0:
        return onehot

    if column.min().item() < 0 or column.max().item() >= n:
        raise ValueError('indices must lie in [0, {})'.format(n))

    return onehot.scatter_(1, column, 1)



def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print the mean loss.

    Args:
        epoch: Epoch number, used only in the printed summary line.
    """
    model.train()
    running_loss = 0

    for images, digits in train_loader:
        # Binarise the pixel intensities and one-hot encode the digit labels.
        images = images.round().to(device)
        condition = idx2onehot(digits, 10).to(device)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images, condition)

        # Flatten each image so it matches the shape of the reconstruction.
        flat_images = images.view(-1, images.shape[2] * images.shape[3])
        batch_loss = loss_function(reconstruction, flat_images, mu, logvar)
        batch_loss.backward()
        running_loss += batch_loss.item()
        optimizer.step()

    # The loss is summed per batch, so divide by the number of samples.
    mean_loss = running_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, mean_loss))





In [None]:
def test(epoch):
    """Evaluate the model on ``test_loader`` and print the mean loss.

    Args:
        epoch: Accepted for symmetry with ``train``; not used in the body.
    """
    model.eval()
    total_loss = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, digits in test_loader:
            # Same preprocessing as in training: binarise and one-hot encode.
            images = images.round().to(device)
            condition = idx2onehot(digits, 10).to(device)

            reconstruction, mu, logvar = model(images, condition)
            flat_images = images.view(-1, images.shape[2] * images.shape[3])
            batch_loss = loss_function(reconstruction, flat_images, mu, logvar)
            total_loss += batch_loss.item()

    # The loss is summed per batch, so divide by the number of samples.
    mean_loss = total_loss / len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(mean_loss))


# Train and evaluate for EPOCH epochs. After each epoch, decode a batch of
# prior samples conditioned on a single digit and save them as an image grid.
image_save_path = 'images'
# Create the output directory up front so the first save cannot fail on a
# missing folder after a full epoch of training.
os.makedirs(image_save_path, exist_ok=True)

for epoch in range(1, EPOCH + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Latent codes drawn from the standard normal prior (16 codes, 20-d).
        sample = torch.randn(16, 20).to(device)

        # Condition every sample on the digit 5, encoded as a one-hot row.
        c = torch.full((sample.shape[0], 1), 5, dtype=torch.long)
        c = idx2onehot(c, 10).to(device)

        sample = model.decode(sample, c).cpu()

        # Round the decoder outputs to a binary image, as done for the inputs.
        generated_image = sample.round()
        save_image(generated_image.view(16, 1, 28, 28),
                   os.path.join(image_save_path, 'sample_{}.png'.format(epoch)))



====> Epoch: 1 Average loss: 77.7605
====> Test set loss: 80.3823
====> Epoch: 2 Average loss: 77.7095
====> Test set loss: 80.2210
====> Epoch: 3 Average loss: 77.6666
====> Test set loss: 80.2843
====> Epoch: 4 Average loss: 77.6426
====> Test set loss: 80.4796
====> Epoch: 5 Average loss: 77.6790
====> Test set loss: 79.8406
