In [None]:
# import statements

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import yaml

from models import beta_vae
from experiment import VAEXperiment

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# CelebA training split, resized to 64x64 and converted to tensors.

batch_size = 16
root = "" # path to CelebA dataset

# Resize first, then convert the resized image to a tensor.
downsize = torchvision.transforms.Resize((64, 64))
to_tensor = torchvision.transforms.ToTensor()
composed_transform = torchvision.transforms.Compose([downsize, to_tensor])

trainset = torchvision.datasets.CelebA(
    root=root,
    split='train',
    download=True,
    transform=composed_transform,
)

# Subset restricted to the first 2000 images.
# NOTE(review): the loader below iterates the full `trainset`, not this
# subset -- confirm which of the two is intended.
trainset_abridged = torch.utils.data.Subset(trainset, range(2000))

# shuffle=False: batches come out in dataset order.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Build the beta-VAE, wrap it in the experiment class and restore trained
# weights so the following cells can run inference.

in_channels = 3
latent_dim = 32
loss_type = 'H'
beta = 10.0

model = beta_vae.BetaVAE(in_channels=in_channels, latent_dim=latent_dim, loss_type=loss_type, beta=beta).to(device)

# Experiment hyper-parameters are read from the 'exp_params' section of the
# YAML config.
# NOTE(review): yaml.safe_load would suffice if the config only holds plain
# scalars/maps -- confirm before switching.
param_path = "/PyTorch-VAE/configs/bhvae.yaml"
with open(param_path, 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

params = config['exp_params']

vae = VAEXperiment(vae_model=model, params=params)

checkpoint_path = "" # fill in path to checkpoint

# Restore weights from the checkpoint. load_from_checkpoint is invoked on the
# class rather than on the `vae` instance: assuming VAEXperiment is a PyTorch
# Lightning module (TODO confirm), newer Lightning releases refuse the
# instance-call form, while the class-call form works on old and new versions.
# From here on `model` names the experiment wrapper; later cells reach the
# underlying network through `model.model`.
model = VAEXperiment.load_from_checkpoint(checkpoint_path, vae_model=model, params=params)

# Make sure the restored module sits on the same device as the input batch and
# is in inference mode (eval() switches any BatchNorm/Dropout layers to their
# inference behaviour), since the remaining cells only evaluate the model.
model = model.to(device)
model.eval()

In [None]:
# Take one batch of images from the training loader and move it to `device`.
# The loader is built with shuffle=False, so this is the first batch in
# dataset order rather than a random draw. The labels are discarded.

images, _ = next(iter(trainloader))
X = images.to(device)

In [None]:
def show(img):
    """Draw a channel-first image tensor on a new wide, tick-free figure.

    Args:
        img: torch tensor laid out as (C, H, W), e.g. the result of
            ``torchvision.utils.make_grid``. It may live on any device and
            may still be attached to an autograd graph.

    Returns:
        None. The image is drawn on a freshly created figure, which becomes
        the current pyplot figure.
    """
    # Tensor.numpy() only accepts CPU tensors that do not require grad, so
    # detach and copy to host first (both are no-ops when already satisfied).
    npimg = img.detach().cpu().numpy()
    plt.figure(figsize=(20,5))
    # no ticks
    plt.xticks([])
    plt.yticks([])
    # imshow wants channel-last data: (C, H, W) -> (H, W, C)
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

In [None]:
# Show the first 8 images of the batch X as an image grid.

X_viz = X[:8].cpu()

# Tile the images into a single grid tensor (up to 10 images per row, so the
# 8 images land on one row) and display it.
grid = torchvision.utils.make_grid(
    X_viz.reshape(-1, 3, 64, 64),
    nrow=10,
)
show(grid)

In [None]:
# Visualize reconstructions of the first 8 images of the batch.

# Inference only: no_grad() keeps autograd from recording the forward pass,
# so no graph is kept alive for a tensor that is merely plotted.
# The first element of the model output is used as the reconstruction
# (presumably the decoded images -- confirm against the model's forward()).
with torch.no_grad():
    X_rec = model(X[:8])[0].cpu()

# Visualize X_rec on torch grid

grid = torchvision.utils.make_grid(X_rec.reshape(-1,3,64,64), nrow=10)
show(grid)

In [None]:
# Compute the latent code for the first image of the batch.
# X[:1] keeps the leading batch dimension (same as X[0].unsqueeze(0)).
# Run under no_grad(): the result is only used as a starting point for
# visualisation, so there is no reason to track gradients through the encoder.
# `mu` / `logvar` are presumably the posterior mean and log-variance -- confirm
# against the encoder's return value.

with torch.no_grad():
    mu, logvar = model.model.encode(X[:1])

In [None]:
# Latent traversal: for each dimension d in d_list, sweep that coordinate of
# the latent code over [-3, 3] while every other coordinate stays at z.

d_list = [1, 2, 11]
n_traversals = 10
z = mu

with torch.no_grad():
    # Flatten z to a single row so it can be tiled; its own shape, device and
    # dtype define the traversal tensor (no dependence on other cells' constants).
    z_base = z.detach().reshape(1, -1)
    traversal_values = torch.linspace(
        -3, 3, n_traversals, device=z_base.device, dtype=z_base.dtype
    )

    # One block of n_traversals copies of z per traversed dimension ...
    z_traversals = z_base.repeat(len(d_list) * n_traversals, 1)
    for i, d in enumerate(d_list):
        # ... and within block i, column d is overwritten with the sweep values.
        z_traversals[i * n_traversals:(i + 1) * n_traversals, d] = traversal_values

    # Decode z_traversals

    X_traversals = model.model.decode(z_traversals).cpu()

# Visualize X_traversals on torch grid: one row per traversed dimension, so the
# row length must follow n_traversals rather than a fixed number.

grid = torchvision.utils.make_grid(X_traversals.reshape(-1,3,64,64), nrow=n_traversals)
show(grid)

# Save figure (creating the output directory first so savefig cannot fail on a
# missing folder)

output_path = '../results/beta_vae_traversals.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches='tight', dpi=300)