In [None]:
import torch
import yaml
from vae import VAE
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

In [None]:
device = 'cpu'

# Load the MNIST VAE: build it from the YAML model config, then restore the
# best checkpoint's weights.
config_path = 'VAE-Pytorch/config/vae_kl.yaml'
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

vae = VAE(config['model_params'])
vae.to(device)
# This notebook only runs inference, so switch off training-mode behaviour
# (e.g. dropout, batch-norm statistic updates) in case the VAE uses such layers.
vae.eval()
pth_path = 'VAE-Pytorch/vae_kl/best_vae_kl_ckpt.pth'
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
vae.load_state_dict(torch.load(pth_path, map_location=device))

In [None]:
# Test on MNIST images: encode a training "8" and "9", then decode points
# along several paths between their latent posteriors.

# Every z below uses the reparameterization trick z = mu + std * eps with
# eps ~ N(0, I). Seed the RNG so a fresh Restart & Run All reproduces the figures.
torch.manual_seed(0)

N_STEPS = 10  # interpolation intervals; each sweep yields N_STEPS + 1 frames

path1 = 'VAE-Pytorch/data/train/images/8/6545.png'
path2 = 'VAE-Pytorch/data/train/images/9/6167.png'

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor()
])


def load_image(path):
    """Open `path` as grayscale and return a (1, 1, 28, 28) tensor on `device`."""
    with Image.open(path) as pil_img:
        return transform(pil_img.convert('L')).unsqueeze(0).to(device)


def lerp(start, end, n_steps=N_STEPS):
    """Stack n_steps + 1 evenly spaced points from `start` to `end`, both inclusive.

    The result gains a new leading dimension of size n_steps + 1.
    """
    return torch.stack([start + (end - start) * i / n_steps for i in range(n_steps + 1)])


def sample(mu, std):
    """Draw z = mu + std * eps using fresh noise eps ~ N(0, I) shaped like `std`."""
    return mu + std * torch.randn_like(std)


def decode_all(latents):
    """Decode each latent with `vae.generate`, without tracking gradients.

    Iterates over the leading dimension of a stacked tensor, or over a list.
    """
    with torch.no_grad():
        return [vae.generate(z) for z in latents]


img1 = load_image(path1)
img2 = load_image(path2)

# Pure inference: no autograd graph is needed for the encoder pass.
with torch.no_grad():
    out1 = vae(img1)
    out2 = vae(img2)

# The VAE output dict provides the posterior 'mean' and 'log_variance'.
mu1 = out1['mean']
std1 = torch.exp(0.5 * out1['log_variance'])
z1 = sample(mu1, std1)
mu2 = out2['mean']
std2 = torch.exp(0.5 * out2['log_variance'])
z2 = sample(mu2, std2)

# 1) Interpolate between the two sampled latents z1 and z2.
interps = lerp(z1, z2)
gen_imgs = decode_all(interps)

# 2) Interpolate mu and std jointly, drawing fresh noise at every step.
interps_mu = lerp(mu1, mu2)
interps_std = lerp(std1, std2)
interps_z = [sample(m, s) for m, s in zip(interps_mu, interps_std)]
gen_imgs_z = decode_all(interps_z)

# 3) Two-phase path. First move mu from mu1 to mu2 with std held at std1.
#    Then hold mu at mu2 and move std from std1 to std2.
#    The last frame of phase one and the first frame of phase two share
#    (mu2, std1) but use different noise draws.
interps_z21 = torch.stack([sample(m, std1) for m in interps_mu])
interps_z22 = torch.stack([sample(mu2, s) for s in interps_std])
interps_z2 = torch.cat((interps_z21, interps_z22), dim=0)
gen_imgs_z2 = decode_all(interps_z2)

# 4) Interpolate the means only, with no std or noise. This is the same
#    path as interps_mu.
interps_z3 = interps_mu
gen_imgs_z3 = decode_all(interps_z3)

# 5) Interpolate mu1 and mu2 with a constant std of 0.5 on every latent dim.
CONSTANT_STD = 0.5
std4 = CONSTANT_STD * torch.ones_like(std1)
interps_z4 = [sample(m, std4) for m in interps_mu]
gen_imgs_z4 = decode_all(interps_z4)

In [None]:
# Plot gen_imgs: decoded frames along the straight line from z1 to z2.
# The column count follows the number of frames; squeeze=False keeps `axes`
# a 2-D array even for a single frame.
fig, axes = plt.subplots(1, len(gen_imgs), figsize=(20, 2), squeeze=False)
for ax, frame in zip(axes.flat, gen_imgs):
    ax.imshow(frame.squeeze().detach().cpu().numpy(), cmap='gray')
    ax.axis('off')
fig.suptitle('Interpolation between sampled latents z1 and z2')
plt.show()

In [None]:
# Plot gen_imgs_z: mean and std interpolated together, with fresh noise per frame.
# The column count follows the number of frames; squeeze=False keeps `axes` 2-D.
fig, axes = plt.subplots(1, len(gen_imgs_z), figsize=(20, 2), squeeze=False)
for ax, frame in zip(axes.flat, gen_imgs_z):
    ax.imshow(frame.squeeze().detach().cpu().numpy(), cmap='gray')
    ax.axis('off')
fig.suptitle('Joint interpolation of mean and std (resampled noise per frame)')
plt.show()

In [None]:
# Plot gen_imgs_z2 as two rows. Row 1 is the first half of the frames (mean
# moving, std fixed at std1). Row 2 is the second half (mean fixed at mu2,
# std moving).
# axes.flat is row-major, so zipping it with the frames fills row 1, then row 2.
n_cols = len(gen_imgs_z2) // 2
fig, axes = plt.subplots(2, n_cols, figsize=(20, 4), squeeze=False)
for ax, frame in zip(axes.flat, gen_imgs_z2):
    ax.imshow(frame.squeeze().detach().cpu().numpy(), cmap='gray')
    ax.axis('off')
fig.suptitle('Top: mean interpolated (std = std1) / bottom: std interpolated (mean = mu2)')
plt.show()

In [None]:
# Plot gen_imgs_z3: decoded posterior means only, with no sampling noise.
# The column count follows the number of frames; squeeze=False keeps `axes` 2-D.
fig, axes = plt.subplots(1, len(gen_imgs_z3), figsize=(20, 2), squeeze=False)
for ax, frame in zip(axes.flat, gen_imgs_z3):
    ax.imshow(frame.squeeze().detach().cpu().numpy(), cmap='gray')
    ax.axis('off')
fig.suptitle('Interpolation of the posterior means only (no noise)')
plt.show()

In [None]:
# Plot gen_imgs_z4: interpolated means sampled with a constant std.
# The column count follows the number of frames; squeeze=False keeps `axes` 2-D.
fig, axes = plt.subplots(1, len(gen_imgs_z4), figsize=(20, 2), squeeze=False)
for ax, frame in zip(axes.flat, gen_imgs_z4):
    ax.imshow(frame.squeeze().detach().cpu().numpy(), cmap='gray')
    ax.axis('off')
fig.suptitle('Interpolation of the means with a constant std')
plt.show()