<a href="https://colab.research.google.com/github/ij264/Corpus-Drawing-Project/blob/master/sketchRNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install ipython-autotime

%load_ext autotime

In [None]:
# imports
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable

In [None]:
# hyperparameters
# UPDATE

# Hyperparameters shared by the data-loading, model and training cells of this notebook.
hp = {
    # path of the .npz stroke archive that is handed to np.load
    'location': '/content/drive/Shared drives/Corpus Drawing Project/data/sketchrnn_airplane.npz',
    'Nz': 128,  # size of the latent vector z (output width of the encoder's fc_mu / fc_sigma)
    'batch_size': 100,  # batch size passed to get_batch and to the encoder in Model.train
    'encoder_hidden_size': 256,  # hidden size of the encoder LSTM (per direction)
    'decoder_hidden_size': 512,  # hidden size of the decoder LSTM
    'temperature': 0.9,  # not referenced elsewhere in this notebook yet; presumably a sampling temperature — TODO confirm
    'gradient_clipping': 1.0,  # not referenced yet; gradient clipping is still a TODO in Model
    'lr': 1e-4,  # learning rate given to both Adam optimisers in Model
    'KL_min': 0.2,  # not referenced yet; presumably a floor on the KL loss term — TODO confirm
    'R': 0.9999,  # not referenced yet; presumably the KL annealing rate — TODO confirm
    'M': 20,  # presumably the number of mixture components in the decoder output — TODO confirm
    'WKL': 1.0,  # not referenced yet; presumably the weight of the KL loss term — TODO confirm
    'dropout': 0.1  # probability passed to nn.Dropout in the encoder and decoder
}

In [None]:
# Mount Google Drive at /content/drive (Colab only); hp['location'] points inside this mount.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# DONE
# returns maximum sequence length in stroke sequences in data
def get_max_length(data):
    """Return the length of the longest sequence contained in ``data``.

    ``data`` is any iterable of sized items (here: one stroke array per
    sketch). ``max`` raises ``ValueError`` when ``data`` is empty.
    """
    longest = max(map(len, data))
    return longest

In [None]:
class DataLoader(object):
    """Convert raw stroke-3 sketches into one padded, normalised stroke-5 tensor.

    Construction runs ``preprocess`` -> ``pad_data`` -> ``normalise``. Afterwards
    ``self.strokes`` is a float64 tensor of shape
    ``(num_sketches, max_seq_length + 1, 5)`` with columns
    ``(dx, dy, p1, p2, p3)``, ordered by increasing original sketch length.

    Args:
        strokes: iterable of per-sketch arrays in stroke-3 format
            (DeltaX, DeltaY, pen binary state).
        batch_size: only used to derive ``self.num_batches``.
        random_scale_factor: stored; no method of this class reads it yet.
        augment_stroke_prob: stored; no method of this class reads it yet.
        limit: values are clamped to ``[-limit, limit]`` in ``preprocess``.
    """

    def __init__(self,
                 strokes,
                 batch_size=5,
                 random_scale_factor=0.0,
                 augment_stroke_prob=0.0,
                 limit=1000):
        self.batch_size = batch_size  # Batch size.
        self.max_seq_length = get_max_length(strokes)  # Nmax.
        self.random_scale_factor = random_scale_factor  # Data augmentation method.

        # Removes large gaps in data. x and y offsets are clamped to have
        # absolute values no greater than this limit.
        self.limit = limit
        self.augment_stroke_prob = augment_stroke_prob
        self.start_stroke_token = torch.Tensor([0, 0, 1, 0, 0])  # S_0 in the paper.

        # preprocess() leaves self.strokes as a length-sorted list of stroke-3
        # arrays; pad_data() replaces it with the stroke-5 tensor; normalise()
        # rescales the offsets of that tensor in place.
        self.preprocess(strokes)
        self.pad_data(self.strokes, self.max_seq_length)
        self.normalise()

    def preprocess(self, strokes):
        """Clamp, filter and length-sort the raw sketches.

        Sketches longer than ``self.max_seq_length`` are dropped, the rest are
        clamped to ``[-self.limit, self.limit]`` and stored in ``self.strokes``
        sorted by length. Also sets ``self.num_batches``.

        Returns:
            The sorted list of stroke-3 arrays (also stored on ``self.strokes``).
        """
        kept = []
        for sketch in strokes:
            if len(sketch) <= self.max_seq_length:
                # removes large gaps from the data
                sketch = np.minimum(sketch, self.limit)
                sketch = np.maximum(sketch, -self.limit)
                kept.append(sketch)

        lengths = np.array([len(sketch) for sketch in kept])  # n points per sketch
        self.strokes = [kept[i] for i in np.argsort(lengths)]

        count_data = len(kept)
        print("total images <= max_seq_len is %d" % count_data)

        self.num_batches = count_data // self.batch_size
        return self.strokes

    def calculate_normalizing_scale_factor(self):
        """Calculate the normalizing factor explained in appendix of sketch-rnn.

        The factor is the standard deviation of the real (dx, dy) offsets only.
        Pen-state columns, the S_0 start row and end-of-sketch padding rows are
        excluded, because including them would shrink the standard deviation.

        Works on ``self.strokes`` in either state: the padded stroke-5 tensor
        produced by ``pad_data`` or the list of stroke-3 arrays produced by
        ``preprocess``.

        Returns:
            A 0-dim tensor holding the standard deviation.
        """
        strokes = self.strokes
        if torch.is_tensor(strokes):
            drawn = strokes[:, 1:, :]  # row 0 of every sketch is S_0
            is_real = drawn[:, :, 4] == 0  # p3 == 1 marks padding rows
            offsets = drawn[:, :, 0:2][is_real]
        else:
            offsets = torch.from_numpy(
                np.concatenate(
                    [np.asarray(sketch, dtype=float)[:, 0:2] for sketch in strokes]
                )
            )
        return torch.std(offsets)

    def normalise(self):
        ''' Normalise entire dataset by normalising factor.

        Divides the dx/dy columns of the padded tensor in place, so it must run
        after ``pad_data``. Returns the normalised tensor.
        '''
        scale_factor = self.calculate_normalizing_scale_factor()
        self.strokes[:, :, 0:2] /= scale_factor
        return self.strokes

    def pad_data(self, data, max_len):
        ''' Pad the batch to be stroke-5 bigger format as described in paper.

        Args:
            data: sequence of stroke-3 arrays, each at most ``max_len`` long.
            max_len: padded length excluding the prepended start token.

        Returns:
            Tensor of shape (len(data), max_len + 1, 5), also stored on
            ``self.strokes``.
        '''
        padded_data = np.zeros((len(data), max_len + 1, 5), dtype=float)

        for i, sketch in enumerate(data):
            l = len(sketch)
            assert l <= max_len
            padded_data[i, 0:l, 0:2] = sketch[:, 0:2]
            padded_data[i, 0:l, 3] = sketch[:, 2]
            padded_data[i, 0:l, 2] = 1 - padded_data[i, 0:l, 3]
            padded_data[i, l:, 4] = 1  # everything after the sketch is end-of-sketch
            # put in the first token, as described in sketch-rnn methodology:
            # shift every row down by one, then overwrite row 0 with S_0
            padded_data[i, 1:, :] = padded_data[i, :-1, :]
            padded_data[i, 0, :] = 0
            padded_data[i, 0, 2] = self.start_stroke_token[2]  # setting S_0 from paper.
            padded_data[i, 0, 3] = self.start_stroke_token[3]
            padded_data[i, 0, 4] = self.start_stroke_token[4]

        self.strokes = torch.from_numpy(padded_data)
        return self.strokes

In [None]:
# Load the stroke archive named in hp['location'].
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, so only open archives from a trusted source.
data = np.load(hp['location'], encoding='latin1', allow_pickle=True)
train_strokes = data['train']
test_strokes = data['test']  # loaded here but not used by the cells visible in this notebook
# Project DataLoader (defined above) with its default arguments; .strokes is the tensor it built.
train_set = DataLoader(train_strokes).strokes

# NOTE(review): batch_size is hard-coded to 1000 here while hp['batch_size'] is 100 — confirm which is intended.
# list(...) consumes the iterator once, so `batches` holds a single shuffled pass over train_set.
batches = list(torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle = True)) # this function creates batches automatically
# preprocessing done!!!!

In [None]:
# UPDATE
# encoder RNN
class EncoderRNN(nn.Module):
    """Bidirectional LSTM encoder that maps a stroke-5 sequence to a latent sample.

    Layer sizes come from the module-level ``hp`` dict
    (``encoder_hidden_size``, ``Nz``, ``dropout``).
    """

    def __init__(self):
        super(EncoderRNN, self).__init__()

        # bidirectional LSTM
        self.LSTM = nn.LSTM(5,  # input vector is 5x1
                            hp['encoder_hidden_size'],
                            bidirectional=True
                            )
        self.dropout = nn.Dropout(hp['dropout'])
        # mu and sigma from LSTM's output (forward and backward states concatenated)
        self.fc_mu = nn.Linear(2 * hp['encoder_hidden_size'],
                               hp['Nz'])
        self.fc_sigma = nn.Linear(2 * hp['encoder_hidden_size'],
                                  hp['Nz'])

        self.train()

    def forward(self, inputs, batch_size, hidden_cell=None):
        """Encode ``inputs`` and sample a latent vector.

        Args:
            inputs: tensor whose last dimension is 5; it is cast to float32.
                The LSTM is not batch-first and the default state is
                (2, batch_size, hidden), so ``inputs`` is expected as
                (seq_len, batch_size, 5).
            batch_size: batch dimension used to build the default zero state.
            hidden_cell: optional ``(hidden, cell)`` initial state; zeros on the
                input's device when omitted.

        Returns:
            ``(z, mu, sigma_hat)``, each of shape (batch_size, Nz), where
            ``z = mu + exp(sigma_hat / 2) * N`` and ``N ~ Normal(0, 1)``.
        """
        inputs = inputs.float()
        if hidden_cell is None:
            # initialise with zeros
            hidden = torch.zeros(2, batch_size, hp['encoder_hidden_size'],
                                 device=inputs.device)
            cell = torch.zeros_like(hidden)
            hidden_cell = (hidden, cell)

        # only the final hidden state is needed; the output tensor and final
        # cell state are discarded
        _, (hidden, _) = self.LSTM(inputs, hidden_cell)
        # index 0 is the forward direction, index 1 the backward direction
        hidden_forward, hidden_backward = torch.split(self.dropout(hidden), 1, 0)
        hidden_concat = torch.cat([hidden_forward.squeeze(0),
                                   hidden_backward.squeeze(0)], 1)

        mu = self.fc_mu(hidden_concat)
        sigma_hat = self.fc_sigma(hidden_concat)
        sigma = torch.exp(sigma_hat / 2)

        # standard-normal noise with mu's shape, dtype and device
        N = torch.randn_like(mu)

        z = mu + N * sigma

        return z, mu, sigma_hat

In [None]:
class DecoderRNN(nn.Module):
    """Layers of the sketch decoder: latent projection, LSTM, dropout, output head.

    Layer sizes come from the module-level ``hp`` dict (``Nz``,
    ``decoder_hidden_size``, ``dropout``, ``M``).

    NOTE(review): no ``forward`` is defined yet, so calling an instance raises
    ``NotImplementedError`` from ``nn.Module``.
    """

    def __init__(self):
        super(DecoderRNN, self).__init__()

        # NOTE(review): takes 2 * Nz inputs and emits decoder_hidden_size values,
        # while the encoder produces a z of size Nz — confirm the intended
        # input/output sizes before wiring this layer into forward().
        self.fc_hc = nn.Linear(2 * hp['Nz'],
                               hp['decoder_hidden_size'])

        self.LSTM = nn.LSTM(hp['Nz'] + 5,  # input of decoder is output of encoder (latent vector of size Nz) as well as the previous data point, S_{i-1}
                            hp['decoder_hidden_size'])
        self.dropout = nn.Dropout(hp['dropout'])
        # output vector, y. Registered on self so its parameters are trained.
        # Size 6 * M + 3: six values per mixture component (pi, mu_x, mu_y,
        # sigma_x, sigma_y, rho_xy) plus three pen-state logits — assumes
        # hp['M'] is the mixture count; TODO confirm.
        self.fc_y = nn.Linear(hp['decoder_hidden_size'], 6 * hp['M'] + 3)


In [None]:
# Sketch-RNN model: wraps the bidirectional-LSTM encoder and the LSTM decoder together with their optimisers
class Model():
    """Training wrapper holding the encoder, the decoder and one Adam optimiser each.

    ``train`` stores the encoder outputs (``mu``, ``sigma_hat``) and the decoder
    outputs (``pi``, ``mu_x``, ``mu_y``, ``sigma_x``, ``sigma_y``, ``rho_xy``,
    ``q``) on the instance; the loss methods read them from there.

    The unfinished ``def something`` stub and the fixed mu=0 / sigma=1
    ``MultivariateNormal`` placeholder are replaced by ``bivariate_normal_PDF``,
    which evaluates the density with the decoder's own parameters.
    """

    def __init__(self):
        # encoder (bidirectional LSTM)
        self.encoder = EncoderRNN()

        # decoder
        self.decoder = DecoderRNN()

        # TODO: implement gradient clipping
        self.encoder_optimiser = optim.Adam(self.encoder.parameters(), hp['lr'])
        self.decoder_optimiser = optim.Adam(self.decoder.parameters(), hp['lr'])

    # bivariate normal distribution probability distribution function
    def bivariate_normal_PDF(self, Dx, Dy):
        """Density of the offsets (Dx, Dy) under the stored bivariate normals.

        What we want is to be able to have a given Dx and Dy offset and then
        compute the probabilities of having these offsets given a mean and a
        standard deviation. Reads ``self.mu_x``, ``self.mu_y``,
        ``self.sigma_x``, ``self.sigma_y`` and ``self.rho_xy`` (tensors), which
        must broadcast against ``Dx`` and ``Dy``.
        """
        z = (Dx - self.mu_x) ** 2 / self.sigma_x ** 2 \
            - 2 * self.rho_xy * (Dx - self.mu_x) * (Dy - self.mu_y) / (self.sigma_x * self.sigma_y) \
            + (Dy - self.mu_y) ** 2 / self.sigma_y ** 2
        one_minus_rho_sq = 1 - self.rho_xy ** 2
        prefactor = 1 / (2 * np.pi * self.sigma_x * self.sigma_y * torch.sqrt(one_minus_rho_sq))

        return prefactor * torch.exp(-z / (2 * one_minus_rho_sq))

    # reconstruction loss
    def LR(self, Dx, Dy, p):
        """Reconstruction loss: offset term ``LS`` plus pen-state term ``LP``.

        Args:
            Dx, Dy: target offsets.
            p: target pen states, same shape as ``self.q``.

        Both terms are divided by the module-level ``Nmax``, which is not
        defined in the visible part of this notebook — TODO confirm it is set
        before this is called.
        """
        # kept as a tensor (no .item()) so gradients flow back to the decoder
        PDF = self.bivariate_normal_PDF(Dx, Dy)
        # NOTE(review): the inner sum reduces over every axis before the log is
        # taken; if pi/PDF carry time or batch axes this probably needs to sum
        # over the mixture axis only — confirm once the decoder output layout
        # is fixed.
        LS = -1 / float(Nmax) * torch.sum(
            torch.log(
                torch.sum(self.pi * PDF)
            )
        )
        LP = -1 / float(Nmax) * torch.sum(
            p * torch.log(self.q)  # Do we need to sum over this twice for our two indices, p and q
        )
        return LS + LP

    # KL divergence loss
    # use pytorch function for this (James?)
    def KL(self, Dx, Dy):
        """KL loss computed from the stored ``self.mu`` and ``self.sigma_hat``.

        ``Dx`` and ``Dy`` are accepted for signature compatibility but unused.
        """
        return -1 / (2 * float(hp['Nz'])) * torch.sum(
            1 + self.sigma_hat - torch.square(self.mu) - torch.exp(self.sigma_hat)
        )

    # EQ. 7 USE SOFTMAX IN PYTORCH
    # q = nn.Softmax(self.q_hat)

    def train(self, epoch):
        """Run the encoder and decoder on one batch and store their outputs.

        NOTE(review): ``get_batch`` and ``Nmax`` are module-level names that
        are not defined in the visible part of this notebook. No loss is
        computed and no optimiser step is taken here yet; ``epoch`` is unused.
        """
        self.encoder.train()
        self.decoder.train()

        batch, lengths = get_batch(hp['batch_size'])

        z, self.mu, self.sigma_hat = self.encoder(batch, hp['batch_size'])
        # `clip_grad_norm` helps prevent the exploding gradient problem
        # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        S_0 = torch.stack([torch.Tensor([0, 0, 1, 0, 0])] * hp['batch_size']).unsqueeze(0)
        initial_batch = torch.cat([S_0, batch], 0)
        stacked_z = torch.stack([z] * (Nmax + 1))
        inputs = torch.cat([initial_batch, stacked_z], 2)

        # at each step in time, the current data point as well as z is inputted into the decoder.
        self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y, self.rho_xy, self.q, _, _ = self.decoder(inputs, z)

# Coding playground

In [None]:
# Layer normalization: normalise over every dimension except the leading one.
Random = torch.randn(20, 5, 10, 10)
m = nn.LayerNorm(Random.shape[1:])
m(Random)

In [None]:
# nn.KLDivLoss expects `input` to hold log-probabilities and `target` to hold
# probabilities. Raw randn values are neither, which is why comparing k with
# itself did not give zero. reduction='batchmean' matches the mathematical
# definition of KL divergence (the default 'mean' averages over every element).
Func = nn.KLDivLoss(reduction='batchmean')
k = torch.randn(3, 3)
l = torch.randn(3, 3)
k_log_probs = torch.log_softmax(k, dim=1)  # each row is a log-distribution
print(Func(k_log_probs, k_log_probs.exp()))  # identical distributions -> ~0

In [None]:
# kl_div takes log-probabilities as `input` and probabilities as `target`.
# Built from the same logits, the two distributions match and the result is ~0.
torch.nn.functional.kl_div(torch.log_softmax(k, dim=1), torch.softmax(k, dim=1), reduction='batchmean')