# C-VAE for MNIST

For further information, read the [Conditional Variational Autoencoder tutorial](https://wiseodd.github.io/techblog/2016/12/17/conditional-vae/).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [None]:
# Device configuration: use the GPU when CUDA is available, else the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Set Hyper-parameters (change None)
BATCH_SIZE = None     # batch size handed to both data loaders
LEARNING_RATE = None  # learning rate handed to the optimizer
N_EPOCH = None        # number of epochs the training loop runs


In [None]:
# MNIST Dataset
# Both splits live under the same root folder and are converted to tensors;
# only the training split asks for a download.
original_train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
original_test_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    transform=transforms.ToTensor(),
    download=False,
)

# Data Loader (Input Pipeline)
# The training batches are shuffled; the test batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=original_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=original_test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)


In [None]:
class CVAE(nn.Module):
    """Skeleton of a conditional variational autoencoder (exercise template).

    Every method body below is a placeholder: the layers are not defined yet
    and the methods return ``None``. They must be filled in before the model
    can be trained.
    """
    def __init__(self, x_dim, z_dim, c_dim):
        """Set up the module; the encoder and decoder layers go here.

        Args:
            x_dim: presumably the size of one input sample -- TODO confirm.
            z_dim: presumably the size of the latent vector -- TODO confirm.
            c_dim: presumably the size of the condition vector -- TODO confirm.
        """
        super(CVAE, self).__init__()
        
        #######################################
        ##       Define Encoder layers       ##
        ## use linear or convolutional layer ##
        #######################################
        
        #######################################
        ##       Define Decoder layers       ##
        ## use linear or convolutional layer ##
        #######################################
    
    def encoder(self, x, c):
        """Encode input ``x`` under condition ``c`` (placeholder, returns None).

        The latent-collection cell further down stores this method's return
        value as the latent batch, so the final return shape must suit that use.
        """

        return None
    
    def decoder(self, z, c):
        """Decode latent ``z`` under condition ``c`` (placeholder, returns None).

        The visualization cell further down plots this method's return value
        as an image.
        """

        return None
    
    def sampling(self, mu, log_var):
        """Return ``eps * std + mu``.

        ``std`` and ``eps`` are placeholders; presumably ``std`` is derived from
        ``log_var`` and ``eps`` is random noise -- to be filled in. Until then
        the return line fails because ``eps`` is ``None``.
        """
        std = None
        eps = None
        return eps.mul(std).add(mu)
    
    
    def forward(self, x, c):
        """Run the full model on ``x`` and ``c`` (placeholder, returns None).

        The training loop unpacks the result as ``(recon_batch, mu, log_var)``,
        so the finished implementation has to return three values.
        """
        
        return None

In [None]:
# Create Model (change None)
cond_dim = None    # passed to the model as c_dim
latent_dim = None  # passed to the model as z_dim

# Build the network (x_dim is a placeholder too) and move it to the
# selected device in the same statement.
cvae = CVAE(x_dim=None, z_dim=latent_dim, c_dim=cond_dim).to(device)

In [None]:
# Your Model
# A bare expression on the last line of a notebook cell is shown as the cell
# output, so this displays the model.
cvae

In [None]:
# Adam optimizer over all model parameters, using LEARNING_RATE as step size
optimizer = optim.Adam(cvae.parameters(), lr=LEARNING_RATE)

# return reconstruction error + KL divergence losses
def loss_function(recon_x, x, mu, log_var):
    """Return the weighted sum of the KL term and the reconstruction term.

    Exercise placeholder: ``kl_loss`` and ``recon_loss`` are ``None`` and must
    be replaced with real expressions; until then the return line raises.

    Args:
        recon_x: the model's reconstruction (the training loop passes
            ``recon_batch`` here).
        x: the original input batch.
        mu: second value returned by the model; presumably the latent mean.
        log_var: third value returned by the model; presumably the latent
            log-variance.

    Returns:
        ``1 * kl_loss + 1 * recon_loss``. Both weights are 1 and may be tuned.
    """
    kl_loss = None
    recon_loss = None
    return 1 * kl_loss + 1 * recon_loss #You can change constants

In [None]:
# Train
# One pass over train_loader per epoch; the summed batch losses are reported
# once per epoch.
for epoch in range(1, N_EPOCH + 1):
    cvae.train()
    train_loss = 0
    for (data, cond) in train_loader:
        data = data.to(device)

        # Exercise placeholder: replace None with the one-hot encoding of the
        # labels currently held in `cond` (F.one_hot is one way to build it).
        cond = None # create one-hot condition
        # Fixed: this line used the undefined name `conde`.
        cond = cond.to(device)

        optimizer.zero_grad()

        recon_batch, mu, log_var = cvae(data, cond)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    # Divides by the number of samples, so this is a per-sample average only
    # if loss_function returns a sum over the batch -- confirm when writing it.
    print('Epoch: {}/{}\t Average loss: {:.4f}'.format(epoch, N_EPOCH, train_loss / len(train_loader.dataset)))

In [None]:
# Visualization
# Decode one fixed latent sample under each of the ten one-hot conditions and
# show the results side by side.
import matplotlib.pyplot as plt
import numpy as np  # fixed: np was used below without being imported

digit_size = 28
n_digits = 10
# The same latent sample is reused for every digit, so only the condition
# changes from one panel to the next.
# Assumes the decoder takes float32 tensors on `device`, like the batches fed
# to the model during training.
z_sample = torch.as_tensor(np.random.rand(1, latent_dim),
                           dtype=torch.float32, device=device) # random
plt.figure(figsize=(20, 1))

# Evaluation mode is loop-invariant, so it is set once up front.
cvae.eval()
for i in range(n_digits):
    # One-hot condition that selects digit i.
    c = np.zeros((1, cond_dim))
    c[0][i] = 1
    c = torch.as_tensor(c, dtype=torch.float32, device=device)

    with torch.no_grad():
        img = cvae.decoder(z_sample, c)
        # reshape (if needed)
        # Assumes the decoder returns exactly one digit_size x digit_size
        # image in any layout -- adjust if the decoder output differs.
        img = img.cpu().numpy().reshape(digit_size, digit_size)

    # Fixed: subplot needs (rows, columns, index); the column count was missing.
    plt.subplot(1, n_digits, i + 1)
    plt.axis('off')
    plt.imshow(img, cmap='Greys_r')
plt.show()


# UMAP
These links will help you understand how to build an AnnData object and plot a UMAP embedding with Scanpy.

[Scanpy anndata](https://anndata.readthedocs.io/en/stable/)

[Scanpy umap](https://icb-scanpy.readthedocs-hosted.com/en/stable/api/scanpy.pl.umap.html)

[Example](https://icb-scanpy-tutorials.readthedocs-hosted.com/en/latest/visualizing-marker-genes.html)

In [None]:
# Create latent space and labels for UMAP
# Walks the test set once and accumulates the labels and the encoder output
# of every batch.
# NOTE: the UMAP cell calls labels.astype(str) and passes latent to
# sc.AnnData, so both should end up as NumPy arrays.
cvae.eval()
latent = None
labels = None
with torch.no_grad():
    for data, cond in test_loader:
        data = data.to(device)

        # `cond` still holds the raw labels here; store them before it is
        # overwritten by the one-hot condition below.
        if labels is None:
            # Fixed: the first batch of labels was assigned to `latent`.
            labels = cond
        else:
            # Exercise placeholder (replace None). A comment alone in this
            # branch was a syntax error.
            labels = None # concatenate labels and cond

        cond = None # create one-hot condition
        # Fixed: this line used the undefined name `conde`.
        cond = cond.to(device)

        batch_latent = cvae.encoder(data, cond)
        if latent is None:
            latent = batch_latent
        else:
            # Exercise placeholder (replace None), same as for the labels.
            latent = None # concatenate latent and batch_latent


In [None]:
# UMAP
# Wrap the collected latent vectors in an AnnData object, build the neighbor
# graph, compute the UMAP embedding and plot it colored by digit label.
import scanpy as sc

# Labels are converted to strings, presumably so that the plot colors the
# digits as discrete categories rather than on a continuous scale.
labels = labels.astype(str)
# Fixed: obs referenced the undefined name `label` instead of `labels`.
latent_anndata = sc.AnnData(X=latent,
                        obs={"Numbers": labels})
sc.pp.neighbors(latent_anndata)
sc.tl.umap(latent_anndata)

# Visualization
# NOTE(review): legend_loc is passed as False here -- confirm that the
# installed scanpy version accepts this value.
sc.pl.umap(latent_anndata, color=["Numbers"],
           frameon=False,
           legend_loc=False,
           show=True)