In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from VAE import VAE

import scanpy as sc

import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as T
from torch.optim import lr_scheduler

### Binary MNIST

In [2]:
# Binary MNIST: each image is flattened to a 784-vector and thresholded at 0.5
# into {0., 1.} pixels; each integer digit label is turned into a length-10
# one-hot vector.
image_pipeline = T.Compose([
    T.ToTensor(),
    T.Lambda(torch.flatten),
    T.Lambda(lambda img: (img > 0.5).float()),
])
label_pipeline = T.Compose([
    T.Lambda(lambda digit: torch.LongTensor([digit])),
    T.Lambda(lambda idx: F.one_hot(idx, 10)),
    T.Lambda(torch.squeeze),
])
dataset = MNIST(
    '../data/',
    train=True,
    download=True,
    transform=image_pipeline,
    target_transform=label_pipeline,
)

# Shuffled mini-batches of 128; the final incomplete batch is dropped.
trainloader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)

# Conditional VAE over the 784 pixels, conditioned on the 10 digit classes.
# NOTE(review): kl_weight=0 presumably switches the KL term off entirely --
# confirm against the VAE implementation that this is intended.
cvae = VAE(
    n_features=784,
    z_dim=32,
    layer_sizes=[128, 128],
    generative_model='bernoulli',
    kl_weight=0,
    n_conditions=10,
)
optimizer = torch.optim.Adam(cvae.parameters(), lr=5e-3)

In [None]:
# Train the conditional VAE for a fixed number of epochs and report the mean
# per-batch loss after each one.
n_epoch = 10
for epoch in range(n_epoch):
    epoch_loss_list = []
    for x, label in trainloader:
        optimizer.zero_grad()

        # Call the module itself rather than `.forward` so that any registered
        # hooks run; only the third returned value (the loss) is needed here.
        _, _, loss = cvae(x, condition_labels=label)

        loss.backward()
        optimizer.step()

        # Store a plain Python float: np.mean over a list of tensors depends on
        # implicit tensor->ndarray conversion and fails for non-CPU tensors.
        epoch_loss_list.append(loss.item())

    print(f'Epoch {epoch+1} mean loss: {np.mean(epoch_loss_list):.4f}')

In [None]:
# For every digit class, decode the all-zero latent vector conditioned on that
# class and show the decoder output as a 28x28 image (one figure per class).
for digit in range(10):
    class_one_hot = torch.zeros(1, 10)
    class_one_hot[0, digit] = 1
    decoded = cvae.decoder(torch.zeros(1, 32), condition_labels=class_one_hot)
    plt.figure()
    plt.imshow(decoded[0].detach().reshape(28, 28))

In [None]:
# Show one training digit next to the model's reconstruction of it.
x, label = next(iter(trainloader))

# The model is conditional, so it is given the same one-hot labels it was
# trained with (the original passed None as the condition). Gradients are not
# needed for visualisation.
with torch.no_grad():
    out = cvae(x, condition_labels=label)

# NOTE(review): out[0] is plotted as the sampled reconstruction and out[1] as
# the per-pixel probabilities -- confirm this ordering against VAE.forward.
f, ax = plt.subplots(ncols=4, figsize=(14,6))
ax[0].imshow(x[0].detach().reshape(28,28))
ax[0].set_title('input')
ax[1].imshow(out[1][0].detach().reshape(28,28))
ax[1].set_title('reconstruction probabilities')
ax[2].imshow(out[1][0].detach().reshape(28,28)>0.5)
ax[2].set_title('reconstruction thresholded')
ax[3].imshow(out[0][0].detach().reshape(28,28))
ax[3].set_title('reconstruction sampled')
plt.show()

In [None]:
# `dataset.data` holds the raw uint8 images and bypasses the `transform`
# pipeline, so the preprocessing used for training is repeated here: scale to
# [0, 1] (what ToTensor does), flatten to 784 features, and binarise at 0.5.
# Without this the model would be fed 0-255 intensities it never saw in training.
X = (dataset.data.reshape(-1,784).float() / 255 > 0.5).float()

# Integer digit labels (not one-hot), used later to colour the embedding.
labels = dataset.targets

In [None]:
_ = vae.forward(X)

In [None]:
import umap
import seaborn as sns

reducer = umap.UMAP()

z_umap = reducer.fit_transform(vae.z.detach())

sns.scatterplot(x=z_umap[:,0], y=z_umap[:,1], hue=labels)