VAE viewer notebook for JetBot
===

This notebook visualizes images reconstructed by a trained VAE (variational autoencoder), using frames captured from the JetBot camera.

In [1]:
import PIL
import PIL.Image
import numpy as np
import cv2
import urllib
import urllib.request
import ipywidgets.widgets as widgets
from IPython.display import display
import torch
from torchvision.transforms import transforms
from vae import VAE

## Setting Parameters

|Name | Description| Default|
|:----|:-----------|:-------|
|IMAGE_CHANNELS | Number of image channels (RGB) | 3 (do not change)|
|VARIANTS_SIZE  | Size of the VAE latent vector (`z_dim`) | 32          |
|MODEL_PATH     | Path to the trained VAE model file | ../../../vae.torch|

In [2]:
IMAGE_CHANNELS, VARIANTS_SIZE = 3, 32  # colour channels per frame, VAE latent (z) dimension
MODEL_PATH = '../../../vae.torch'      # trained VAE weights, resolved against the current working directory

## Load trained VAE model
Load the trained VAE weights and move the model to the GPU.

In [3]:
# Prefer the GPU, but fall back to CPU so the notebook still runs where CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = VAE(image_channels=IMAGE_CHANNELS, z_dim=VARIANTS_SIZE)
# map_location remaps tensors that were saved on another device onto `device`.
# NOTE(review): torch.load unpickles the file -- only load model files from a trusted source.
vae.load_state_dict(torch.load(MODEL_PATH, map_location=device))
vae.to(device).eval()

VAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
    (5): ReLU()
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
    (7): ReLU()
    (8): Flatten()
  )
  (fc1): Linear(in_features=6144, out_features=32, bias=True)
  (fc2): Linear(in_features=6144, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=6144, bias=True)
  (decoder): Sequential(
    (0): UnFlatten()
    (1): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2))
    (2): ReLU()
    (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2))
    (4): ReLU()
    (5): ConvTranspose2d(64, 32, kernel_size=(5, 5), stride=(2, 2))
    (6): ReLU()
    (7): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2))
    (8): Sigmoid()
  )
)

## Define preprocessing and JPEG-encoding helpers

In [4]:
def preprocess(image):
    """Turn a camera frame into the tensor fed to the VAE.

    The frame is resized to 160x120. The top 40 rows are then cut away,
    leaving a 160x80 region. That region is converted with ``ToTensor``
    into a float CHW tensor.

    Args:
        image: camera frame as a NumPy array (assumed HxWx3 uint8; the
            channel order is passed through unchanged).

    Returns:
        torch.Tensor of the cropped 160x80 region.
    """
    to_tensor = transforms.ToTensor()
    frame = PIL.Image.fromarray(image).resize((160, 120))
    roi = frame.crop((0, 40, 160, 120))
    return to_tensor(roi)

def preprocess_without_crop(image):
    """Convert a camera frame to a tensor without resizing it first.

    Despite the name, this still crops: it keeps the box (0, 40, 320, 240).
    For a 320x240 frame, that drops the top 40 rows. Unlike ``preprocess``,
    no resize happens before the crop.

    Args:
        image: camera frame as a NumPy array (assumed 320x240 -- the crop box
            is hard-coded for that size).

    Returns:
        torch.Tensor of the cropped region, as produced by ``ToTensor``.
    """
    region = PIL.Image.fromarray(image).crop((0, 40, 320, 240))
    return transforms.ToTensor()(region)
    
def rgb8_to_jpeg(image):
    """Encode an image array as JPEG bytes suitable for an ipywidgets ``Image``.

    Note: ``cv2.imencode`` interprets a 3-channel array in OpenCV's B, G, R
    channel order, despite this function's name.

    Args:
        image: image array accepted by ``cv2.imencode``.

    Returns:
        bytes: the encoded JPEG data.

    Raises:
        RuntimeError: if OpenCV reports that encoding failed (previously the
            failure flag was silently ignored).
    """
    ok, encoded = cv2.imencode('.jpg', image)
    if not ok:
        raise RuntimeError('cv2.imencode failed to encode image as JPEG')
    return encoded.tobytes()

## Latent space visualization functions

In [5]:
# Latent values in [-ABS_LATENT_MAX_VALUE, ABS_LATENT_MAX_VALUE] are mapped linearly onto the colour ramp.
ABS_LATENT_MAX_VALUE = 3
# Each latent dimension is drawn as a PANEL_HEIGHT x PANEL_WIDTH pixel tile.
PANEL_HEIGHT = PANEL_WIDTH = 10

def sigmoid(x, gain=1, offset_x=0):
    """Logistic sigmoid of ``gain * (x + offset_x)``, element-wise.

    Evaluated through ``tanh`` using the identity
    ``1 / (1 + exp(-u)) == (tanh(u / 2) + 1) / 2``. This avoids overflow in
    ``exp`` for large ``|u|``.

    Args:
        x: scalar or NumPy array.
        gain: steepness of the transition.
        offset_x: shift applied to ``x`` before scaling; the midpoint sits at
            ``x == -offset_x``.
    """
    scaled = (x + offset_x) * gain
    return (np.tanh(scaled / 2) + 1) / 2

def color_bar_rgb(x):
    """Map a normalised value onto a blue -> green -> red colour ramp.

    The input is first rescaled from [0, 1] to [-1, 1]. On that scale:

    - blue falls off around -0.2;
    - red rises around +0.2;
    - green is a bump between -0.6 and +0.6.

    Args:
        x: scalar or NumPy array, nominally in [0, 1].

    Returns:
        list: ``[blue, green, red]`` intensities in [0, 255]. Note that the
        order is B, G, R despite the function name.
    """
    gain = 10
    edge_shift = 0.2   # where red rises / blue falls
    green_shift = 0.6  # half-width of the green bump
    centred = (x * 2) - 1
    red = sigmoid(centred, gain, -1 * edge_shift)
    blue = 1 - sigmoid(centred, gain, edge_shift)
    # Rising edge at -green_shift, falling edge at +green_shift.
    green = (sigmoid(centred, gain, green_shift)
             + (1 - sigmoid(centred, gain, -1 * green_shift))) - 1.0
    return [blue * 255, green * 255, red * 255]

def _get_color(value):
    """Return the ``color_bar_rgb`` colour for a single latent value.

    The range [-ABS_LATENT_MAX_VALUE, ABS_LATENT_MAX_VALUE] is mapped linearly
    onto [0, 1] before the lookup.
    """
    span = ABS_LATENT_MAX_VALUE * 2.0
    return color_bar_rgb((value + ABS_LATENT_MAX_VALUE) / span)

def create_color_panel(latent_spaces):
    """Render a latent vector as a horizontal strip of solid-colour tiles.

    Args:
        latent_spaces: iterable of latent values, one tile per value.

    Returns:
        np.ndarray of shape (PANEL_HEIGHT, PANEL_WIDTH * len(latent_spaces), 3),
        dtype float64, with values clipped to [0, 255].
    """
    # _get_color yields [blue, green, red]; reversing gives [red, green, blue].
    # NOTE(review): the panel is later passed to cv2.imencode, which reads
    # channel 0 as blue -- confirm this reversal is the intended colour mapping.
    tiles = [
        np.clip(
            np.full((PANEL_HEIGHT, PANEL_WIDTH, 3), _get_color(z)[::-1], dtype=np.float64),
            0, 255,
        )
        for z in latent_spaces
    ]
    return np.concatenate(tiles, axis=1)

## Set JetBot IP and port

In [9]:
# Camera frames are fetched from http://<jetbot_ip>:<jetbot_port>/camera.
jetbot_ip, jetbot_port = 'localhost', 8080

## Create GUI

In [7]:
# Left: the cropped 160x80 VAE input; right: its reconstruction (same size).
resize = widgets.Image(format='jpeg', width=160, height=80)
result = widgets.Image(format='jpeg', width=160, height=80)
# One PANEL_WIDTH-wide tile per latent dimension. The width is derived from
# VARIANTS_SIZE rather than a hard-coded 32, so changing the latent size does
# not distort the bar. The PANEL_HEIGHT-pixel strip is displayed 10x taller.
color_bar = widgets.Image(format='jpeg', width=VARIANTS_SIZE*PANEL_WIDTH, height=10*PANEL_HEIGHT)
display(widgets.HBox([resize, result]))
display(color_bar)

HBox(children=(Image(value=b'', format='jpeg', height='80', width='160'), Image(value=b'', format='jpeg', heig…

Image(value=b'', format='jpeg', height='100', width='320')

## Run this cell repeatedly to fetch a new camera frame and see its reconstruction

In [13]:
def vae_process(img):
    """Run one camera frame through the VAE and refresh the display widgets.

    Updates three widgets:

    - ``resize`` shows the preprocessed VAE input;
    - ``result`` shows the decoded reconstruction;
    - ``color_bar`` shows the latent vector as colour tiles.

    Args:
        img: camera frame as a NumPy array (as returned by ``cv2.imdecode``).
    """
    preprocessed_image = preprocess(img)
    # CHW float tensor in [0, 1] -> HWC uint8 for JPEG encoding.
    resize.value = rgb8_to_jpeg(np.transpose(np.uint8(preprocessed_image * 255), [1, 2, 0]))
    # Inference only: skip building the autograd graph to save time and memory.
    with torch.no_grad():
        z, _, _ = vae.encode(torch.unsqueeze(preprocessed_image, dim=0).to(device))
        reconst = vae.decode(z)
    reconst = reconst.detach().cpu()[0].numpy()
    # Reverse the channel axis of the HWC reconstruction before encoding.
    # NOTE(review): originally commented as "change to RGB", but cv2.imencode
    # expects B, G, R -- confirm against the channel order used in training.
    reconst = np.transpose(np.uint8(reconst * 255), [1, 2, 0])[:, :, ::-1]
    result.value = rgb8_to_jpeg(reconst)
    latent_space = z.detach().cpu().numpy()[0]
    color_bar.value = rgb8_to_jpeg(create_color_panel(latent_space))

# Fetch a single JPEG frame from the JetBot camera server and visualise it.
camera_url = 'http://{0}:{1}/camera'.format(jetbot_ip, jetbot_port)
# Use the response as a context manager so the connection is always closed.
with urllib.request.urlopen(camera_url) as req:
    arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
# IMREAD_COLOR guarantees a 3-channel image, matching IMAGE_CHANNELS;
# IMREAD_UNCHANGED (-1) could return grayscale or 4-channel data.
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
    # imdecode signals failure by returning None rather than raising.
    raise ValueError('Could not decode camera frame fetched from {0}'.format(camera_url))
vae_process(img)