<a href="https://colab.research.google.com/github/ij264/Corpus-Drawing-Project/blob/master/sketchRNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install and load the ipython-autotime IPython extension
# (presumably for per-cell execution timing).
!pip install ipython-autotime

%load_ext autotime

In [None]:
# imports
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable

In [None]:
# hyperparameters
# UPDATE

# Notebook-wide hyperparameters, looked up by key (hp['...']) in the cells below.
# Entries marked "not used yet" are not referenced by the code written so far.
hp = dict(
    # .npz dataset archive loaded with np.load
    location='/content/drive/Shared drives/Corpus Drawing Project/data/sketchrnn_airplane.npz',
    Nz=128,                    # size of the latent vector produced by the encoder
    batch_size=1,              # drawings per training batch
    encoder_hidden_size=256,   # hidden size of the bidirectional encoder LSTM
    decoder_hidden_size=512,   # presumably the decoder RNN's hidden size (decoder not written yet)
    temperature=0.9,           # not used yet; presumably the sampling temperature
    gradient_clipping=1.0,     # not used yet; intended gradient-clipping threshold
    lr=1e-4,                   # Adam learning rate
    KL_min=0.2,                # not used yet; presumably the floor on the KL loss
    R=0.9999,                  # not used yet; presumably the KL-weight annealing rate
    M=20,                      # not used yet; presumably the number of mixture components
    WKL=1.0,                   # not used yet; presumably the weight of the KL loss
    dropout=0.1,               # dropout probability applied to the encoder's hidden state
)

In [None]:
# Mount Google Drive at /content/drive so that the dataset file named in
# hp['location'] (a path under that mount point) can be read.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# DONE
def max_size(data):
    """Return the length of the longest sequence in ``data``.

    Each element is sized with ``len`` (for a 2-D stroke array that is its
    number of rows).

    Args:
        data: iterable of sized sequences (e.g. the stroke arrays of a split).

    Returns:
        The maximum element length, or 0 when ``data`` is empty (``max()``
        would otherwise raise ``ValueError``).
    """
    return max((len(seq) for seq in data), default=0)

In [None]:
# Load the train/test splits from the .npz archive. The archive is opened as a
# context manager so the underlying file handle is released; indexing a split
# reads it fully into memory, so the arrays stay usable after the `with` block.
# allow_pickle=True is needed because the per-drawing arrays presumably have
# different lengths and are stored as object arrays.
# NOTE(review): unpickling can execute arbitrary code — only load archives
# from a trusted source.
with np.load(hp['location'], encoding='latin1', allow_pickle=True) as data:
    training_data = data['train']
    testing_data = data['test']

# length of the longest training drawing; get_batch pads every drawing to this
Nmax = max_size(training_data)

In [None]:
# returns batches of size batch_size
def get_batch(data, batch_size, max_len=None):
    """Sample a random batch of drawings and convert it to stroke-5 rows.

    Drawings are chosen uniformly at random *with replacement* from ``data``.
    Each input row (DeltaX, DeltaY, s) becomes a row (DeltaX, DeltaY, p1, p2, p3)
    with p1 = 1 - s and p2 = s for every point except the last. The last real
    point and all zero-padding rows after it get (p1, p2, p3) = (0, 0, 1).

    Args:
        data: indexable collection of 2-D arrays with 3 columns.
        batch_size: number of drawings to sample.
        max_len: length every drawing is padded to; defaults to the global
            ``Nmax`` (longest training drawing).

    Returns:
        ``(batch, lengths)``: ``batch`` is a float32 tensor of shape
        (max_len, batch_size, 5) (time-major, drawings stacked on axis 1) and
        ``lengths`` lists the unpadded length of each sampled drawing.

    Raises:
        ValueError: if a sampled drawing is longer than ``max_len``.
    """
    if max_len is None:
        max_len = Nmax

    batch_idx = np.random.choice(len(data), batch_size)  # random indices, with replacement
    strokes = []
    lengths = []

    for idx in batch_idx:
        sequence = data[idx]
        sequence_len = len(sequence)
        if sequence_len > max_len:
            raise ValueError(
                f"drawing of length {sequence_len} exceeds max_len={max_len}")

        new_sequence = np.zeros((max_len, 5))  # each row: DeltaX, DeltaY, p1, p2, p3
        new_sequence[:sequence_len, :2] = sequence[:, :2]          # DeltaX, DeltaY
        new_sequence[:sequence_len - 1, 2] = 1 - sequence[:-1, 2]  # p1
        new_sequence[:sequence_len, 3] = sequence[:, 2]            # p2
        new_sequence[sequence_len - 1:, 4] = 1                     # p3 from the last point on
        new_sequence[sequence_len - 1, 2:4] = 0                    # last point carries only p3
        lengths.append(sequence_len)
        strokes.append(new_sequence)

    # Plain tensor: torch.autograd.Variable is deprecated and adds nothing here.
    batch = torch.from_numpy(np.stack(strokes, 1)).float()

    return batch, lengths

In [None]:
# UPDATE
# encoder RNN
class EncoderRNN(nn.Module):
    """Bidirectional-LSTM encoder producing a latent sample z.

    The final forward and backward hidden states are concatenated and
    projected to ``mu`` and ``sigma_hat``; z is drawn as
    ``mu + exp(sigma_hat / 2) * eps`` with ``eps ~ N(0, I)``.

    Args:
        hparams: optional dict with keys ``'encoder_hidden_size'``, ``'Nz'``
            and ``'dropout'``. Defaults to the notebook-level ``hp``.
    """

    def __init__(self, hparams=None):
        super(EncoderRNN, self).__init__()
        if hparams is None:
            hparams = hp  # fall back to the notebook-wide hyperparameters
        self.hidden_size = hparams['encoder_hidden_size']
        self.Nz = hparams['Nz']

        # bidirectional LSTM over 5-dimensional input rows (DeltaX, DeltaY, p1, p2, p3)
        self.LSTM = nn.LSTM(input_size=5,
                            hidden_size=self.hidden_size,
                            num_layers=1,
                            bias=True,
                            batch_first=False,
                            bidirectional=True)
        self.dropout = nn.Dropout(hparams['dropout'])
        # mu and sigma_hat are projected from the concatenated forward/backward states
        self.fc_mu = nn.Linear(in_features=2 * self.hidden_size,
                               out_features=self.Nz)
        self.fc_sigma = nn.Linear(in_features=2 * self.hidden_size,
                                  out_features=self.Nz)

        self.train()

    def forward(self, inputs, batch_size, hidden_cell=None):
        """Encode ``inputs`` and sample z with the reparameterisation trick.

        Args:
            inputs: tensor of shape (seq_len, batch_size, 5) — time-major,
                because the LSTM is built with ``batch_first=False``.
            batch_size: batch dimension, used to build the zero initial state.
            hidden_cell: optional ``(h0, c0)``, each of shape
                (2, batch_size, encoder_hidden_size); zeros when omitted.

        Returns:
            ``(z, mu, sigma_hat)``, each of shape (batch_size, Nz).
        """
        if hidden_cell is None:
            # 2 = num_layers (1) * num_directions (2); match the input's device
            hidden = torch.zeros(2, batch_size, self.hidden_size, device=inputs.device)
            cell = torch.zeros(2, batch_size, self.hidden_size, device=inputs.device)
            hidden_cell = (hidden, cell)

        # only the final hidden state is needed; the per-step output and cell are discarded
        _, (hidden, _cell) = self.LSTM(inputs.float(), hidden_cell)
        hidden_forward, hidden_backward = torch.split(self.dropout(hidden), 1, 0)
        hidden_concat = torch.cat([hidden_forward.squeeze(0),
                                   hidden_backward.squeeze(0)], 1)

        mu = self.fc_mu(hidden_concat)
        sigma_hat = self.fc_sigma(hidden_concat)
        sigma = torch.exp(sigma_hat / 2)

        noise = torch.randn_like(mu)  # eps ~ N(0, I), same shape/device as mu
        z = mu + noise * sigma

        return z, mu, sigma_hat

In [None]:
class DecoderRNN(nn.Module):
    """Decoder network — placeholder, not implemented yet.

    Exists so the notebook parses and ``DecoderRNN()`` can be constructed;
    calling it raises ``NotImplementedError`` until the decoder is written.
    NOTE(review): it has no parameters yet, so building an optimiser over
    ``DecoderRNN().parameters()`` will fail until layers are added.
    """

    def forward(self, *args, **kwargs):
        raise NotImplementedError("DecoderRNN is not implemented yet")

In [None]:
# NN model: holds the bidirectional-LSTM encoder, the decoder, their optimisers and the loss terms
class Model():
    def __init__(self):
        """Create the encoder and decoder networks and one Adam optimiser for each."""
        # encoder: bidirectional LSTM returning the latent sample z (see EncoderRNN)
        self.encoder = EncoderRNN()

        # decoder network (see DecoderRNN)
        self.decoder = DecoderRNN()

        # TODO: implement gradient clipping (hp['gradient_clipping']) in the training step
        self.encoder_optimiser = optim.Adam(self.encoder.parameters(), hp['lr'])
        self.decoder_optimiser = optim.Adam(self.decoder.parameters(), hp['lr'])

    # bivariate normal distribution probability density function
    def bivariate_normal_PDF(self, Dx, Dy):
        """Evaluate the bivariate normal density at the offsets ``(Dx, Dy)``.

        Uses the distribution parameters stored on the instance
        (``self.mu_x``, ``self.mu_y``, ``self.sigma_x``, ``self.sigma_y``,
        ``self.rho_xy``), broadcast against ``Dx`` and ``Dy``.
        NOTE(review): these attributes are presumably filled in from the
        decoder output before this is called — confirm.

        Returns:
            The density value(s), with the broadcast shape of the inputs.
        """
        # standardised offsets
        norm_x = (Dx - self.mu_x) / self.sigma_x
        norm_y = (Dy - self.mu_y) / self.sigma_y
        one_minus_rho2 = 1 - self.rho_xy ** 2

        z = norm_x ** 2 - 2 * self.rho_xy * norm_x * norm_y + norm_y ** 2
        prefactor = 1 / (2 * np.pi * self.sigma_x * self.sigma_y * torch.sqrt(one_minus_rho2))

        # torch.distributions.MultivariateNormal could replace this (and its .sample()
        # draws offsets), but its covariance matrix must then include the rho_xy term.
        return prefactor * torch.exp(-z / (2 * one_minus_rho2))

    # reconstruction loss
    def LR(self, Dx, Dy, p, mask=None):
        """Reconstruction loss: offset term L_s plus pen-state term L_p.

        Args:
            Dx, Dy: target offsets, broadcastable against the mixture
                parameters used by ``self.bivariate_normal_PDF``.
            p: target states, broadcastable against ``self.q``.
            mask: optional 0/1 tensor marking real (non-padding) points; when
                given, only those points contribute to L_s.

        Both terms are normalised by the global padded length ``Nmax``.
        NOTE(review): assumes the mixture components lie on the last axis of
        ``self.Pi``/the PDF — confirm against the decoder output.
        """
        PDF = self.bivariate_normal_PDF(Dx, Dy)
        # mixture likelihood per point (sum over components, not over everything);
        # the small epsilon keeps log() finite if every component underflows
        mixture = torch.sum(self.Pi * PDF, dim=-1)
        log_likelihood = torch.log(mixture + 1e-5)
        if mask is not None:
            log_likelihood = log_likelihood * mask
        LS = -1 / float(Nmax) * torch.sum(log_likelihood)
        LP = -1 / float(Nmax) * torch.sum(
            p * torch.log(self.q)
        )
        return LS + LP

    # KL divergence loss 
    def KL(self, Dx, Dy, mu=None, sigma_hat=None):
        """KL divergence between N(mu, exp(sigma_hat)) and N(0, I).

        L_KL = -1/(2 Nz) * sum(1 + sigma_hat - mu^2 - exp(sigma_hat)),
        where sigma = exp(sigma_hat / 2) as in EncoderRNN.forward.

        Args:
            Dx, Dy: unused; kept so existing call sites keep working.
            mu, sigma_hat: encoder outputs (EncoderRNN.forward returns
                ``(z, mu, sigma_hat)``). Default to ``self.mu`` /
                ``self.sigma_hat`` — presumably stored by the training step;
                confirm.
        """
        if mu is None:
            mu = self.mu
        if sigma_hat is None:
            sigma_hat = self.sigma_hat
        return -1 / (2 * float(hp['Nz'])) * torch.sum(
            1 + sigma_hat - torch.square(mu) - torch.exp(sigma_hat)
        )

# EQ. 7 USE SOFTMAX IN PYTORCH
    @property
    def q(self):
        """Probabilities q = softmax(q_hat) (Eq. 7), used by the L_p term of LR.

        Computed on access from ``self.q_hat``. A class-level assignment cannot
        reference ``self``, and ``nn.Softmax`` takes a dim, not a tensor.
        NOTE(review): assumes the three logits lie on the last axis of
        ``self.q_hat`` — confirm against the decoder output.
        """
        return F.softmax(self.q_hat, dim=-1)

    def train(self, epoch):
        """Training step for the encoder/decoder pair — work in progress.

        Puts both networks in training mode and samples a batch. NOTE(review):
        as visible here the method cannot run yet (see the inline notes), and
        ``epoch`` is not used in the visible code.
        """
        self.encoder.train()
        self.decoder.train()

        # NOTE(review): get_batch takes the dataset as its first argument
        # (get_batch(data, batch_size)); the data argument is missing here.
        batch, lengths = get_batch(hp['batch_size'])
        
        total_loss = 0
        for i, (data, _) in enumerate(training_set):
            # NOTE(review): loop body missing; `training_set` is not defined in the visible notebook.

        # `clip_grad_norm` helps prevent the exploding gradient problem 
        # NOTE(review): `model` and `args` are undefined here; presumably this should clip the
        # encoder/decoder parameters with hp['gradient_clipping'] — confirm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)