# Conditional VAE

With Input:
- Image Label
- Coordinate

In [1]:
import os
import csv
import numpy as np
import math
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from skimage import io
from PIL import Image
from tqdm import tqdm
#import argparse

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.utils import save_image

matplotlib.style.use('ggplot')

torch.cuda.empty_cache() 
#import model

import ast

## To load in the dataset

In [2]:
class ActiveVisionDataset (Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file, index_col=None)
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self,index):
        if type(index) == torch.Tensor:
            index = index.item()
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        image = io.imread(img_path)
        #image = image/(image.max()/255.0)
        shape_label = torch.tensor(int(self.annotations.iloc[index,1]))
        cam_loc = torch.tensor(ast.literal_eval(self.annotations.iloc[index,2]))
        
        if self.transform:
            image = self.transform(image)
        
        return image, shape_label, cam_loc

## Encoder

In [3]:
class Encoder(nn.Module):
    def __init__(self, z_dim):
        super(Encoder, self).__init__()
        
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=init_kernel, kernel_size=kernel_size, stride=stride, padding=padding
        )
        self.conv2 = nn.Conv2d(
            in_channels=init_kernel, out_channels=init_kernel*2, kernel_size=kernel_size, stride=stride, padding=padding
        )
        self.conv3 = nn.Conv2d(
            in_channels=init_kernel*2, out_channels=init_kernel, kernel_size=kernel_size, stride=stride, padding=padding
        )
        
        #self.dropout = nn.Dropout()
        #self.bn_enc1 = nn.BatchNorm2d(init_kernel)
        self.img_lin1 = nn.Linear(init_kernel*(32**2), 256)
        
        self.cond_lin1 = nn.Linear(4, 256)
        
        self.mu = nn.Linear(512, z_dim)
        self.sigma = nn.Linear(512, z_dim)
        
        self.bn_lin1 = nn.BatchNorm1d(1000)
 
    def forward(self, image, label, coord):
        
        #Image
        #print("before anything")
        #print(image.shape)
        x = self.conv1(image)
        x = F.relu(x)
        #print("before flatten:")
        #print(x.shape)
        x = x.view(x.size(0), -1)
        #print("after flatten:")
        #print(x.shape)
        x = self.img_lin1(x)
        x = F.relu(x)
        
        #Label and Coordinate
        label = torch.unsqueeze(label, dim=1)
        y = torch.cat([coord,label],dim=1)
        #print(label) print(label.shape) print(coord) print(coord.shape) print(y) print(y.shape)
        y = self.cond_lin1(y)
        y = F.relu(y)
        #print(y.shape)
        
        x = torch.cat([x,y],dim=1)
        #print(x.shape)
        
        # get `mu` and `log_var`
        mu = self.mu(x)
        log_var = self.sigma(x)
        
        return mu, log_var

## Decoder

In [4]:
class Decoder(nn.Module):
    def __init__(self, z_dim):
        super(Decoder, self).__init__()
        
        self.img_lin1 = nn.Linear(z_dim, 256)
        self.cond_lin1 = nn.Linear(4, 256)
        
        self.lin1 = nn.Linear(512, init_kernel*(32**2))
        
        self.dec1 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=3, kernel_size=kernel_size, stride=stride, padding=padding
        )
        self.dec2 = nn.ConvTranspose2d(
            in_channels=init_kernel*8, out_channels=init_kernel*4, kernel_size=kernel_size, stride=stride, padding=padding
        )
        self.dec3 = nn.ConvTranspose2d(
            in_channels=init_kernel*4, out_channels=3, kernel_size=kernel_size, stride=stride, padding=padding
        )
        
        #self.dropout = nn.Dropout()
        #self.bn_enc1 = nn.BatchNorm2d(init_kernel*2)
        self.bn_lin4 = nn.BatchNorm1d(init_kernel*119*119)
        
    def forward(self, z, label, coord):
        
        x = self.img_lin1(z)
        
        #Label and Coordinate
        label = torch.unsqueeze(label, dim=1)
        y = torch.cat([coord,label],dim=1)
        
        y = self.cond_lin1(y)
        y = F.relu(y)
        #print(y.shape)
        
        x = torch.cat([x,y],dim=1)
        #print(x.shape)
        
        x = self.lin1(x)
        x=F.relu(x)
        
        x=x.view(-1, init_kernel, 32, 32)
        #print("after unflatten:")
        #print(x.shape)
        
        x = self.dec1(x)
        #x = F.relu(x)
        #x = self.dec2(x)
        #x = F.relu(x)
        #x = self.dec3(x)
        reconstruction = torch.sigmoid(x)
        
        return reconstruction

## VAE

In [5]:
class ConditionalVAE(nn.Module):
    def __init__(self, z_dim):
        super(ConditionalVAE, self).__init__()
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)
    
    def forward(self, image, label, coord):
        mu, log_var = self.encoder(image, label, coord)
        
        #print('mu: ', mu.shape)
        #print('log_var: ', log_var.shape)
        
        #sample z from latent distribution q
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu,std)
        z = q.rsample()
        #print('z shape: ', z.shape)
        
        reconstruction = self.decoder(z, label, coord)
                
        return reconstruction, mu, log_var, z

## Loss Helper Functions

In [6]:
def gaussian_likelihood(mean, logscale, sample):
    scale = torch.exp(logscale)
    dist = torch.distributions.Normal(mean, scale)
    log_pxz = dist.log_prob(sample)
    return log_pxz.sum(dim=(1, 2, 3))

def kl_divergence(z, mu, std):
    # --------------------------
    # Monte carlo KL divergence
    # --------------------------
    # 1. define the first two probabilities (in this case Normal for both)
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    q = torch.distributions.Normal(mu, std)

    # 2. get the probabilities from the equation
    log_qzx = q.log_prob(z)
    log_pz = p.log_prob(z)

    # kl
    kl = (log_qzx - log_pz)
    kl = kl.sum(-1)
    return kl

## Training

In [7]:
def fit(model, dataloader):
    model.train()
    torch.set_grad_enabled(True) #???
    running_loss = 0.0
    
    for batch in tqdm(dataloader):
        image, label, coord = batch
        #print(image.size())
        if torch.cuda.is_available():
            image = image.to(device)
            label = label.to(device)
            coord = coord.to(device)
        optimizer.zero_grad()
        
        reconstruction, mu, log_var, z = model(image, label, coord)
        
        #print(reconstruction.shape)
        
        #image = image.to(torch.device('cpu'))
        recon_loss = gaussian_likelihood(reconstruction, log_scale, image)
        
        std = torch.exp(log_var / 2)
        kl = kl_divergence(z, mu, std)

        elbo = (kl - recon_loss)
        elbo = elbo.mean()
        
        elbo.backward()
        optimizer.step()
        
        running_loss += elbo
    
    train_loss = running_loss/len(dataloader.dataset) #Investigate
    return train_loss

## Validation

In [8]:
def validate(model, dataloader):
    model.eval()
    running_loss = 0.0
    i = 1
    with torch.no_grad():
        for batch in tqdm(dataloader):
            
            image, label, coord = batch
            
            if torch.cuda.is_available():
                image = image.to(device)
                label = label.to(device)
                coord = coord.to(device)
                
            reconstruction, mu, log_var, z = model(image, label, coord)
            
            recon_loss = gaussian_likelihood(reconstruction, log_scale, image)
            
            std = torch.exp(log_var / 2)
            kl = kl_divergence(z, mu, std)

            elbo = (kl - recon_loss)
            elbo = elbo.mean()
            
            running_loss += elbo 
            
            #Add functionality to display every x batches of images
            i+=1
    
    val_loss = running_loss/len(dataloader.dataset)
    return val_loss

## Parameters

In [9]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

enc_out_dim = 512
latent_dim = 128
input_height = 100 #Change to not hard coded

epochs = 20
batch_size = 4
learning_rate = 0.001

kernel_size = 4
stride = 4
padding = 0
init_kernel = 8

train_data = ActiveVisionDataset(csv_file='imgs/TrainSet/rgbCSV.csv', root_dir= 'imgs/TrainSet/rgbImg/', transform = torchvision.transforms.ToTensor())
val_data = ActiveVisionDataset(csv_file='imgs/ValSet/rgbCSVNumbered.csv', root_dir= 'imgs/ValSet/segImg/', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=True)

img_size = len(train_data[0][0][0])
model = ConditionalVAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

log_scale = nn.Parameter(torch.Tensor([0.0])).to(device)

cuda


## Run

In [10]:
def run():
    train_loss = []
    val_loss = []
    for epoch in range(epochs):

        print(f"Epoch {epoch+1} of {epochs}")
        train_epoch_loss = fit(model, train_loader)
        val_epoch_loss = validate(model, val_loader)

        train_loss.append(train_epoch_loss)
        val_loss.append(val_epoch_loss)

        print(f"Train Loss: {train_epoch_loss:.4f}")
        print(f"Val Loss: {val_epoch_loss:.4f}")

In [11]:
run()

  0%|                                                                                          | 0/800 [00:00<?, ?it/s]

Epoch 1 of 20


100%|████████████████████████████████████████████████████████████████████████████████| 800/800 [00:08<00:00, 94.77it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 246.40it/s]
  3%|██▎                                                                             | 23/800 [00:00<00:06, 112.32it/s]

Train Loss: 11404.6338
Val Loss: nan
Epoch 2 of 20


100%|███████████████████████████████████████████████████████████████████████████████| 800/800 [00:06<00:00, 116.54it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 241.78it/s]
  2%|█▏                                                                              | 12/800 [00:00<00:06, 119.30it/s]

Train Loss: 11360.1074
Val Loss: nan
Epoch 3 of 20


  4%|██▉                                                                             | 29/800 [00:00<00:06, 112.96it/s]


KeyboardInterrupt: 