In this notebook we will train a Variational Auto-Encoder VAE, which will (hopefully) translate all UI images into some n-dimensional space, in which we can then do the optimising.

We will use [https://github.com/AntixK/PyTorch-VAE](https://github.com/AntixK/PyTorch-VAE/tree/master) as a basis.

In [1]:
from typing import List, Callable, Union, Any, TypeVar, Tuple
from os import listdir
from os.path import isfile, join
import json
from utils import get_all_bounding_boxes
from math import prod
from loguru import logger
import torch
# from torch import tensor as Tensor

Tensor = TypeVar('torch.tensor')

In [2]:
from torch import nn
from torch.utils.data import Dataset
from abc import abstractmethod

class BaseVAE(nn.Module):
    
    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError
    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        pass

Proposed in_channels:

Almost all UIs will have 3 Buttons and 3 `FIXME` objects. So we will input an ordered list into it of the respective normalised (between 0 and 1) positions.

In [3]:
def get_all_clickable_resources(item, should_be_clickable):
    if item is None:
        return []
    all_boxes = []
    if "bounds" in item.keys() and "resource-id" in item.keys() and "clickable" in item.keys() and item["clickable"]==should_be_clickable:
        all_boxes.append((item["bounds"],item["resource-id"]))
    if "children" in item.keys():
        for child in item["children"]:
            for box in get_all_clickable_resources(child,should_be_clickable):
                all_boxes.append(box)
    return all_boxes

def get_all_bounding_boxes(item, should_be_clickable):
    bboxes = get_all_clickable_resources(item,should_be_clickable)
    reduced_bboxes = []
    already_seen_ids = []
    for box,r_id in bboxes:
        if r_id not in already_seen_ids:
            reduced_bboxes.append(box)
            already_seen_ids.append(r_id)
    return reduced_bboxes


In [4]:
NORMAL_UI_DIMENSIONS = (1440, 2560)

class CustomRicoDataset(Dataset):
    def __init__(self, combined_path="./combined"):
        self.image_files = [
            f
            for f in listdir(combined_path)
            if isfile(join(combined_path, f)) and ("jpg" in f or "jpeg" in f)
        ]
        self.combined_path = combined_path

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = join(self.combined_path,self.image_files[idx])
        json_path = join(self.combined_path,self.image_files[idx].split(".")[0]+".json")
        with open(json_path, "r") as f:
            image_json = json.load(f)
        
        clickable_segments = get_all_bounding_boxes(image_json["activity"]["root"], True)
        reduced_clickable_segments = [
            box
            for box in clickable_segments
            if (prod([box[2]-box[0],box[3]-box[1]]) > 1) and prod([box[2]-box[0],box[3]-box[1]]) < 0.8 * 1440 * 2560
        ]
        normalised_clickable_boxes = []
        for box in reduced_clickable_segments:
            w,h = box[2]-box[0], box[3]-box[1]
            normalised_clickable_boxes.append([w/NORMAL_UI_DIMENSIONS[0],h/NORMAL_UI_DIMENSIONS[1]])
        not_clickable_segments = get_all_bounding_boxes(image_json["activity"]["root"], False)
        not_reduced_clickable_segments = [
            box
            for box in not_clickable_segments
            if (prod([box[2]-box[0],box[3]-box[1]]) > 1) and prod([box[2]-box[0],box[3]-box[1]]) < 0.8 * 1440 * 2560
        ]
        normalised_not_clickable_boxes = []
        for box in not_reduced_clickable_segments:
            w,h = box[2]-box[0], box[3]-box[1]
            normalised_not_clickable_boxes.append([w/NORMAL_UI_DIMENSIONS[0],h/NORMAL_UI_DIMENSIONS[1]])
        

        if len(normalised_clickable_boxes) < 5 or len(normalised_not_clickable_boxes) < 5:
            # logger.warning(f"Ignoring Index {idx} because of too little elements...")
            if idx == (len(self)-1):
                return self[idx-100]
            return self[idx+1]
        # return normalised_clickable_boxes[:5], normalised_not_clickable_boxes[:5]
        return torch.tensor(normalised_clickable_boxes[:5]+normalised_not_clickable_boxes[:5]).unsqueeze(0)


In [5]:
dataset = CustomRicoDataset("../combined")

In [6]:
sample_idx = torch.randint(len(dataset), size=(1,)).item()
j = dataset[sample_idx]

In [7]:
j

tensor([[[1.0000, 0.0766],
         [0.2667, 0.0508],
         [0.6222, 0.0508],
         [0.3333, 0.0656],
         [0.3333, 0.0656],
         [0.9264, 0.0285],
         [0.9264, 0.0016],
         [0.8167, 0.1145],
         [0.9264, 0.0016],
         [0.2431, 0.0070]]])

In [8]:
import torch
# from models import BaseVAE
from torch import nn
from torch.nn import functional as F
# from .types_ import *


class VanillaVAE(BaseVAE):
    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        super(VanillaVAE, self).__init__()

        self.latent_dim = latent_dim

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    #nn.Conv2d(in_channels, out_channels=h_dim,
                    #          kernel_size= 3, stride= 2, padding  = 1),
                    nn.Linear(in_channels, h_dim),
                    # nn.BatchNorm1d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*10, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*10, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 10)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    #nn.ConvTranspose2d(hidden_dims[i],
                    #                   hidden_dims[i + 1],
                    #                   kernel_size=3,
                    #                   stride = 2,
                    #                   padding=1,
                    #                   output_padding=1),
                    #nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.Linear(hidden_dims[i], hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            #nn.ConvTranspose2d(hidden_dims[-1],
                            #                   hidden_dims[-1],
                            #                   kernel_size=3,
                            #                   stride=2,
                            #                   padding=1,
                            #                   output_padding=1),
                            #nn.BatchNorm2d(hidden_dims[-1]),
                            nn.Linear(hidden_dims[-1],hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Linear(hidden_dims[-1],2),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 10, 512)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss =F.mse_loss(recons, input)


        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':-kld_loss.detach()}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [9]:
vae = VanillaVAE(in_channels=2, latent_dim=128)

In [10]:
#j_tensor = torch.tensor(j[0]+j[1]).unsqueeze(0)
recons, inp, mu, log_var = vae.forward(j)

In [11]:
loss = vae.loss_function(recons,inp,mu,log_var, M_N=1)

In [12]:
loss

{'loss': tensor(0.2492, grad_fn=<AddBackward0>),
 'Reconstruction_Loss': tensor(0.1973),
 'KLD': tensor(-0.0520)}

In [13]:
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from accelerate import Accelerator

train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [14]:
accelerator = Accelerator()

In [15]:
optimizer = torch.optim.Adam(vae.parameters(),
                               lr=0.0001)

In [16]:
vae, optimizer, train_dataloader = accelerator.prepare(vae,optimizer, train_dataloader)

In [17]:
vae

VanillaVAE(
  (encoder): Sequential(
    (0): Sequential(
      (0): Linear(in_features=2, out_features=32, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Linear(in_features=64, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
    (4): Sequential(
      (0): Linear(in_features=256, out_features=512, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
  )
  (fc_mu): Linear(in_features=5120, out_features=128, bias=True)
  (fc_var): Linear(in_features=5120, out_features=128, bias=True)
  (decoder_input): Linear(in_features=128, out_features=5120, bias=True)
  (decoder): Sequential(
    (0): Sequential(
      (0): Linear(in_features=512, out_features=25

In [None]:
losses = []
for batch in tqdm(train_dataloader):
    # j_tensor = .to(accelerate)
    # print(batch)
    recons, inp, mu, log_var = vae.forward(batch)
    loss = vae.loss_function(recons,inp,mu,log_var, M_N=64)
    # loss["loss"].backward()
    
    accelerator.backward(loss["loss"])

    torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.5)

    
    optimizer.step()
    
    print(loss["loss"])
    losses.append(loss["loss"])

  0%|          | 0/1036 [00:00<?, ?it/s]

  recons_loss =F.mse_loss(recons, input)


tensor(3.5028, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.7489, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.8661, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5760, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4960, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4275, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4616, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.5043, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4600, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4239, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3891, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3343, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3235, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2948, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2702, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2953, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2776, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3073, device='cuda:0', grad_fn=<AddBack

In [None]:
dataset[0].shape

In [None]:
recons, inp, mu, log_var = vae.forward(dataset[0].to(accelerator.device))

In [None]:
recons

In [None]:
inp