In [None]:
import torch
import yaml
from vae import VAE
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

In [None]:
device = 'cpu'

# Load the MNIST VAE (KL variant): build it from the YAML config, then restore the checkpoint weights.
config_path = 'VAE-Pytorch/config/vae_kl.yaml'
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

vae = VAE(config['model_params'])
vae.to(device)
pth_path = 'VAE-Pytorch/vae_kl/best_vae_kl_ckpt.pth'
# NOTE: torch.load unpickles the file, so only load trusted checkpoints
# (on torch >= 1.13, weights_only=True restricts loading to tensors/plain containers).
vae.load_state_dict(torch.load(pth_path, map_location=device))
# This notebook only runs inference, so put the module in eval mode. That matters if VAE contains
# dropout / batch-norm layers (not visible from here). Trailing ';' suppresses the module repr.
vae.eval();

In [None]:
# Encode two MNIST training images (presumably an 8 and a 9, going by their directories), draw a
# latent sample for each, and decode evenly spaced points on the straight line between the two samples.
SEED = 0
N_STEPS = 10  # interpolation intervals; N_STEPS + 1 frames are decoded, both endpoints included
# z1/z2 below are random reparameterization samples; seed so the figure is reproducible.
torch.manual_seed(SEED)

path1 = 'VAE-Pytorch/data/train/images/8/6545.png'
path2 = 'VAE-Pytorch/data/train/images/9/6167.png'
img1 = Image.open(path1).convert('L')
img2 = Image.open(path2).convert('L')

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor()
])

# unsqueeze(0) adds a leading batch dimension of size 1.
img1 = transform(img1).unsqueeze(0).to(device)
img2 = transform(img2).unsqueeze(0).to(device)

# Inference only: no_grad avoids building an autograd graph for every encoder/decoder pass.
with torch.no_grad():
    out1 = vae(img1)
    out2 = vae(img2)

    # Reparameterization: z = mean + std * eps, with std = exp(0.5 * log_variance), eps ~ N(0, I).
    mu1 = out1['mean']
    std1 = torch.exp(0.5 * out1['log_variance'])
    z1 = mu1 + std1 * torch.randn_like(std1)
    mu2 = out2['mean']
    std2 = torch.exp(0.5 * out2['log_variance'])
    z2 = mu2 + std2 * torch.randn_like(std2)

    # Interpolate linearly between z1 and z2. torch.stack adds a leading step dimension, and
    # iterating over it yields one latent per step, each shaped like z1.
    interps = torch.stack([z1 + (z2 - z1) * i / N_STEPS for i in range(N_STEPS + 1)])
    gen_imgs = [vae.generate(interp) for interp in interps]



In [None]:
# Plot the decoded interpolation frames left to right (first image's latent -> second image's latent).
# The panel count follows len(gen_imgs) rather than being hard-coded, so it tracks the number of steps.
n_frames = len(gen_imgs)
# squeeze=False keeps `axes` 2-D even when there is a single frame, so axes[0] is always iterable.
fig, axes = plt.subplots(1, n_frames, figsize=(20, 2), squeeze=False)
for ax, frame in zip(axes[0], gen_imgs):
    # detach() is a no-op if the frames were produced under no_grad, and required if they were not.
    ax.imshow(frame.squeeze().detach().cpu().numpy(), cmap='gray')
    ax.axis('off')
fig.suptitle('VAE latent-space interpolation')
plt.show()