<a href="https://colab.research.google.com/github/fhswf/ki-wir/blob/main/VAEM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational Autoencoder


In [None]:
# Setup on Colab
!pip install gradio &> /dev/null
!pip install pytorch_lightning &> /dev/null
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash 
!sudo apt-get install git-lfs
!git lfs install

!if [ ! -e ki_wir ]; then git clone https://github.com/fhswf/ki-wir.git ki_wir; else cd ki_wir; git pull; fi
!cd ki_wir; git lfs fetch
!cd ki_wir; git lfs checkout


In [1]:
import gradio as gr
import torch
import numpy as np
import ki_wir.models.vanilla_vae as vanilla_vae
import ki_wir.models.logcosh_vae as logcosh_vae
import ki_wir.models.dfcvae as dfc_vae
import ki_wir.models.experiment as experiment

from PIL import Image
from torchvision import transforms
import torchvision.utils as vutils

In [2]:
%env CUDA_VISIBLE_DEVICES=1
device = torch.device("cuda:0")

env: CUDA_VISIBLE_DEVICES=1


In [4]:
# Constructor arguments shared by all VAE variants below.
params = {
    "in_channels": 3,
    "latent_dim": 128,
    "img_size": 64,
}

# Display name -> [model class, path of the pretrained checkpoint].
config = {
    "DFC": [dfc_vae.DFCVAE, "ki_wir/pretrained/dfc.ckpt"],
    "LogCosh": [logcosh_vae.LogCoshVAE, "ki_wir/pretrained/logcosh.ckpt"],
    "Vanilla": [vanilla_vae.VanillaVAE, "ki_wir/pretrained/vanilla.ckpt"],
}

In [5]:
# Display name -> experiment wrapping the VAE with its pretrained weights.
models = {}

for m, c in config.items():
    model_cls, ckpt_path = c
    model = model_cls(**params)
    # load_from_checkpoint is a classmethod that RETURNS the restored module,
    # so call it on the class and keep the result. map_location lets a
    # checkpoint written on a GPU be restored on whatever `device` is.
    exp = experiment.VAEXperiment.load_from_checkpoint(
        ckpt_path,
        map_location=device,
        vae_model=model,
        params=params,
    )
    exp = exp.to(device)
    # Inference only: BatchNorm must use its running statistics, not the
    # statistics of whatever (possibly single-image) batch is passed in.
    exp.eval()
    models[m] = exp

In [6]:
# Show the loaded experiments; this prints the full module repr of every model.
models

{'DFC': VAEXperiment(
   (model): DFCVAE(
     (encoder): Sequential(
       (0): Sequential(
         (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
         (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (2): LeakyReLU(negative_slope=0.01)
       )
       (1): Sequential(
         (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
         (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (2): LeakyReLU(negative_slope=0.01)
       )
       (2): Sequential(
         (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
         (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (2): LeakyReLU(negative_slope=0.01)
       )
       (3): Sequential(
         (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
         (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True

In [23]:
def reconstruct(name, image1, image2, alpha):
  """Blend two images in pixel space and run the blend through a VAE.

  Args:
    name: Key into the module-level ``models`` dict selecting the experiment.
    image1: First image, as an array accepted by ``PIL.Image.fromarray``.
    image2: Second image, same kind as ``image1``.
    alpha: Blend weight; 0 uses only ``image1``, 1 uses only ``image2``.

  Returns:
    The output of the selected model's ``generate`` for the blended image,
    converted to an RGB ``PIL.Image``.
  """
  # One shared pipeline for both inputs: resize to 64x64 and convert to a
  # float tensor in [0, 1].
  preprocess = transforms.Compose([
      transforms.Resize((64, 64)),
      transforms.ToTensor(),
  ])
  # Force three channels so grayscale / RGBA uploads do not break the
  # RGB-only conversion at the end.
  img1 = preprocess(Image.fromarray(image1).convert("RGB"))
  img2 = preprocess(Image.fromarray(image2).convert("RGB"))

  # Linear interpolation between the two inputs, then rescale [0, 1] -> [-1, 1].
  img = alpha * img2 + (1 - alpha) * img1
  img = 2 * img - 1.

  vae = models[name].model
  # Put the input wherever the model's weights live instead of assuming CUDA.
  target_device = next(vae.parameters()).device
  batch = torch.unsqueeze(img, 0).to(target_device)

  # Pure inference: no autograd graph is needed.
  with torch.no_grad():
    dec = vae.generate(batch, latent_dim=128)

  dec = torch.squeeze(dec[0], 0)
  # Back to [0, 1]; clamp so out-of-range values cannot wrap around when the
  # tensor is converted to 8-bit pixels.
  dec = (0.5 * (dec + 1.)).clamp(0., 1.).cpu()

  return transforms.ToPILImage(mode='RGB')(dec)

In [8]:
test_label = ""

# Draw 144 samples from every loaded model and save them as a 12-per-row grid.
for name, exp in models.items():
    exp.curr_device = device
    # Sampling is inference only; skip building an autograd graph.
    with torch.no_grad():
        samples = exp.model.sample(144, device)
    vutils.save_image(samples.detach().cpu(),
                      f"sample_{name}.png",
                      normalize=True,
                      nrow=12)

In [None]:
# Input widgets: model selector and blend weight between the two uploaded images.
model = gr.inputs.Dropdown(list(models.keys()), type="value", default=None, label="Model")
alpha = gr.inputs.Slider(minimum=0, maximum=1.0, step=0.1, default=0, label=None)
# out1 is kept for compatibility but is not wired up: reconstruct returns a
# single image, and that image is the reconstruction.
out1 = gr.outputs.Image(type="auto", label="original")
out2 = gr.outputs.Image(type="auto", label="reconstructed")
# Keep the Interface object itself in `iface` (launch() returns something
# else), and label the output as what reconstruct actually returns.
iface = gr.Interface(fn=reconstruct, layout="vertical", inputs=[model, "image", "image", alpha], outputs=out2)
iface.launch(debug=True, share=True)


This share link will expire in 72 hours. To get longer links, send an email to: support@gradio.app
