In [None]:
# Imports needed
import numpy as np
np.random.seed(0)
import torch
torch.manual_seed(0)
import random
random.seed(0)
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn import Module
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import transforms
from torch.autograd import Variable

from torchvision.utils import save_image
import cv2
import os
import torchvision
import matplotlib.pyplot as plt

In [None]:
"""
Data needed:
ImageEncoder Input: [PNG Image -> Greyscale], shape = [28x28]
             Output: Latent vector z, shape = [1,128]

SketchDecoderRNN Input: Latent z
                 Output: Stroke sequence
                 
Training:
    -> Compare generated z and EncoderRNN[Corresponding seq to PNG Image]
    -> 3 losses

"""

# Paths and hyperparameters for optimising the photo encoder.
# Photo directories, one per split; the photos are expected to be ordered so
# that they line up with the paired stroke sequences.
train_photo_image_path = "../Datasets/cat/train"
test_photo_image_path = "../Datasets/cat/test"
valid_photo_image_path = "../Datasets/cat/valid"

# Archive holding the stroke-sequence arrays paired with the photos.
seq_data = "../Datasets/cat.npz"

# Number of training epochs, chosen to resemble the sketchRNN training loop.
epochs = 10000
image_size = 28 # photos are image_size x image_size (28 x 28) pixels

Nmax = 200 # largest stroke-sequence length; PhotoToSketch.train stacks z Nmax+1 times
eta_min = 0 # KL-loss term: initial value of PhotoToSketch.eta_step
R = 0 # KL-loss term: eta_step = 1-(1-eta_min)*R, which evaluates to 1 while R is 0

# Load the dataset, converting images and stroke vectors to tensors.
# NOTE(review): LoadData, X and input_sequence_path are not defined in any
# visible cell; the dataset class defined in this notebook is `Cat` and the
# sequence path above is named `seq_data` -- confirm which names are intended.
image_sequence_dataset = LoadData(X, input_sequence_path)

# Wrap the dataset in a DataLoader that yields one shuffled sample per batch.
torch_train_ImgSeq = DataLoader(image_sequence_dataset, shuffle=True, batch_size=1, num_workers=0)

In [None]:
# This class deals with dataset loading
class Cat(Dataset):
    """Paired photo / stroke-sequence dataset (not implemented yet).

    The class body previously held only comments, which is a syntax error in
    Python; this docstring makes the cell importable while the data-loading
    logic is still to be written. Until ``__len__`` and ``__getitem__`` are
    added, the class cannot be indexed or wrapped in a ``DataLoader``.
    """
    # TODO: initialise paths to the X = image, y = paired strokes
    # TODO: function to give the length of dataset
    # TODO: transform the images to tensors
    # TODO: transform the stroke data to tensors
    # TODO: resize the images so they're the correct size

In [None]:
# Models: an image encoder (photo -> latent vector z) and a stroke decoder (latent z -> stroke sequence)

# this class takes in the image and outputs a vector z
class ImageEncoder(Module):
    """Two-layer fully connected encoder mapping a flattened image to a latent vector z.

    Args:
        input_shape: number of input features of the first linear layer.
            When omitted it falls back to ``image_size**2*2`` computed from
            the module-level ``image_size``.
        latent_size: width of the produced latent vector. When omitted it
            falls back to the module-level ``bottleneck_size``.

    NOTE(review): the default input width is twice the pixel count (the
    original comment reads "0 or 1, black and white"), while the notebook
    notes describe a single-channel 28x28 greyscale input -- confirm the
    intended encoding.
    NOTE(review): ``bottleneck_size`` is not defined in the visible cells;
    pass ``latent_size`` explicitly or define it before instantiating.
    """
    def __init__(self, input_shape=None, latent_size=None):
        super().__init__()
        # Fall back to the notebook-level configuration only when the caller
        # did not supply sizes, so ImageEncoder() behaves as before.
        if input_shape is None:
            input_shape = image_size**2*2 # 0 or 1, black and white
        if latent_size is None:
            latent_size = bottleneck_size
        self.enc_h1 = nn.Linear(in_features=input_shape, out_features=128)
        self.enc_out = nn.Linear(in_features=128, out_features=latent_size)

    def forward(self, image):
        """Encode ``image`` into z.

        Args:
            image: tensor whose last dimension equals the first layer's
                ``in_features``; no flattening is performed here.

        Returns:
            Tensor with last dimension ``latent_size``; every entry lies in
            (0, 1) because of the final sigmoid.
        """
        h1_ac = torch.relu(self.enc_h1(image))
        h2_out = self.enc_out(h1_ac)
        # NOTE(review): the sigmoid bounds z to (0, 1); if the target z can
        # hold values outside that range this output cannot match it.
        z = torch.sigmoid(h2_out)
        return z
        
# this class takes in vector z and outputs a sequence for the sketch
class StrokeDecoder(Module):
    """LSTM decoder turning a latent vector z into mixture-density stroke parameters.

    Layer sizes are read from the module-level hyperparameter object ``hp``
    (``Nz``, ``dec_hidden_size``, ``dropout``, ``M``).
    NOTE(review): ``hp`` is not defined in the visible cells; it must exist
    before this class is instantiated.
    """
    def __init__(self):
        # The original called super(Decoder, self), naming a class that does
        # not exist here, which raised NameError on construction.
        super().__init__()
        # define sketchRNN decoder architecture
        # projects z onto the LSTM's initial hidden and cell state (hence 2x)
        self.fc_hc = nn.Linear(hp.Nz, 2*hp.dec_hidden_size)
        # each step consumes a 5-value stroke vector concatenated with z
        self.lstm = nn.LSTM(hp.Nz+5, hp.dec_hidden_size, dropout=hp.dropout)
        # per step: 6 values for each of the M mixture components + 3 pen logits
        self.fc_params = nn.Linear(hp.dec_hidden_size,6*hp.M+3)

    def forward(self, x2_inputs, z_vector, hidden_cell=None):
        """Run the decoder and return the mixture and pen-state parameters.

        Args:
            x2_inputs: LSTM input of shape (seq_len, batch, hp.Nz + 5).
            z_vector: latent vectors of shape (batch, hp.Nz); used to build
                the initial LSTM state when ``hidden_cell`` is not given.
            hidden_cell: optional ``(hidden, cell)`` pair to continue from,
                e.g. during step-by-step generation.

        Returns:
            ``(pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell)``.
            The six mixture tensors have shape (len_out, batch, hp.M) and
            ``q`` has shape (len_out, batch, 3), where ``len_out`` is the
            input sequence length in training mode and 1 otherwise.
        """
        # The original body used `outputs`, `hidden` and `cell` without ever
        # running the LSTM; the state initialisation and the LSTM call below
        # supply them from the layers defined in __init__.
        if hidden_cell is None:
            hidden, cell = torch.split(
                torch.tanh(self.fc_hc(z_vector)), hp.dec_hidden_size, 1)
            hidden_cell = (hidden.unsqueeze(0).contiguous(),
                           cell.unsqueeze(0).contiguous())
        outputs, (hidden, cell) = self.lstm(x2_inputs, hidden_cell)

        if self.training:
            # training: project every time step in one shot
            y = self.fc_params(outputs.view(-1, hp.dec_hidden_size))
            # equals Nmax+1 for the padded training batches, but also stays
            # correct for any other sequence length
            len_out = outputs.size(0)
        else:
            # generation: only the last hidden state is projected
            y = self.fc_params(hidden.view(-1, hp.dec_hidden_size))
            len_out = 1

        # separate pen and mixture params:
        params = torch.split(y,6,1)
        params_mixture = torch.stack(params[:-1]) # trajectory
        params_pen = params[-1] # pen up/down
        # identify mixture params:
        pi,mu_x,mu_y,sigma_x,sigma_y,rho_xy = torch.split(params_mixture,1,2)

        def per_step(t):
            # (M, steps*batch, 1) -> (len_out, batch, M). squeeze(-1) drops
            # only the trailing singleton, so a batch or mixture axis of
            # size 1 is never removed by accident.
            return t.transpose(0,1).squeeze(-1).contiguous().view(len_out,-1,hp.M)

        # preprocess params; softmax is normalised over the mixture axis
        # explicitly instead of relying on the deprecated implicit dim.
        pi = F.softmax(per_step(pi), dim=-1)
        sigma_x = torch.exp(per_step(sigma_x))
        sigma_y = torch.exp(per_step(sigma_y))
        rho_xy = torch.tanh(per_step(rho_xy))
        mu_x = per_step(mu_x)
        mu_y = per_step(mu_y)
        q = F.softmax(params_pen, dim=-1).view(len_out,-1,3)
        return pi,mu_x,mu_y,sigma_x,sigma_y,rho_xy,q,hidden,cell

In [None]:
class PhotoToSketch():
    """Trainer that couples ImageEncoder and StrokeDecoder.

    Relies on module-level names: ``use_cuda``, ``eta_min``, ``R``, ``Nmax``,
    ``hp``, ``torch_train_ImgSeq``, ``extract_lengths``, ``loss_function``
    and ``lr_decay``.
    NOTE(review): ``make_target``, ``kullback_leibler_loss``,
    ``reconstruction_loss``, ``save`` and ``conditional_generation`` are
    called on ``self`` but are not defined in the visible part of the class.
    """
    def __init__(self):
        self.encoder = ImageEncoder()
        self.decoder = StrokeDecoder()
        if use_cuda:
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
        # Each optimizer owns exactly one module's parameters. The original
        # referenced bare `encoder` / `decoder` (NameError) and registered
        # both modules in both optimizers, which would step every parameter
        # twice per iteration.
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=1e-3)
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(), lr=1e-3)
        self.eta_step = eta_min

    def train(self, epoch):
        """Run one pass over ``torch_train_ImgSeq`` and update both networks.

        Args:
            epoch: current epoch index; forwarded to the reconstruction loss
                and used to decide when to checkpoint (every 200 epochs).
        """
        for data in torch_train_ImgSeq:
            # Training loop to update weights for Photo To Sketch model.
            # Train mode is re-enabled on every iteration because
            # conditional_generation() below may change it.
            self.encoder.train()
            self.decoder.train()
            # NOTE(review): extract_lengths is not defined in the visible
            # cells and receives no batch -- confirm where X2 comes from.
            X2, lengths = extract_lengths()

            # The original passed the whole DataLoader to the encoder.
            # assumes the photo is the first element of each batch -- TODO confirm
            photo = data[0]
            z = self.encoder(photo)

            # Start-of-sequence row for every sample. torch.cat along dim 0
            # requires sos and X2 to agree on the batch axis, so the batch
            # size is read from X2 (replacing the undefined `num_of_images`),
            # and sos is created on X2's device.
            batch_size = X2.size(1)
            sos = torch.tensor([0., 0., 1., 0., 0.], device=X2.device).repeat(1, batch_size, 1)
            init = torch.cat([sos, X2], 0)
            z_stack = torch.stack([z]*(Nmax+1))
            inputs = torch.cat([init, z_stack],2)
            self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y, \
                self.rho_xy, self.q, _, _ = self.decoder(inputs, z)
            mask,dx,dy,p = self.make_target(X2, lengths)

            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()

            self.eta_step = 1-(1-eta_min)*R

            # compute all losses: EncoderLoss, LKL, reconstructionLoss
            # NOTE(review): target_z is not defined in this method; per the
            # notebook notes it should be the sketch-RNN encoder's latent for
            # the stroke sequence paired with this photo.
            EL = loss_function(z, target_z)
            KLL = self.kullback_leibler_loss()
            RL = self.reconstruction_loss(mask,dx,dy,p,epoch)
            loss = RL + KLL + EL # may need to include weight here for EL

            loss.backward()
            # clip_grad_norm_ is the in-place replacement for the deprecated
            # clip_grad_norm.
            nn.utils.clip_grad_norm_(self.encoder.parameters(), hp.grad_clip)
            nn.utils.clip_grad_norm_(self.decoder.parameters(), hp.grad_clip)
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()
            # NOTE(review): for integer epochs this condition is always true,
            # so the learning rate decays after every batch.
            if epoch%1==0:
                self.encoder_optimizer = lr_decay(self.encoder_optimizer)
                self.decoder_optimizer = lr_decay(self.decoder_optimizer)
            # NOTE(review): this sits inside the batch loop, so a checkpoint
            # is written for every batch of a qualifying epoch.
            if epoch%200==0:
                self.save(epoch)
                self.conditional_generation(epoch)

#         train_loss = []
#         for epoch in range(epochs):
#             running_loss = 0.0
#             for i,data in enumerate(torch_train_ImgSeq):
#                 photo = data[0]
#                 target_z = data[1]
#                 optimizer.zero_grad()
#                 latent_out = encode(photo)
#                 # calculate the loss dependent on the output z and target z
#                 loss = loss_function(latent_out, target_z)
#                 loss.backward()
#                 optimizer.step()
#                 running_loss += loss.item()
#                 outputs = decode(latent_out) # only do the decoder part when the optimal encoder has been found

#             loss = running_loss / len(torch_train_ImgSeq)
#             train_loss.append(loss)
#             print(f'Epoch {ep+1} of {epochs}, Train Loss: {loss:.5f}')

In [None]:
# Helper functions
def sample_bivariate_normal(mu_x,mu_y,sigma_x,sigma_y,rho_xy, greedy=False, temperature=None):
    """Draw one (x, y) sample from a bivariate normal distribution.

    Args:
        mu_x, mu_y: means of the two coordinates.
        sigma_x, sigma_y: standard deviations before temperature scaling.
        rho_xy: correlation coefficient between the two coordinates.
        greedy: when True, skip sampling and return the means unchanged.
        temperature: multiplier applied to the variances (each sigma is
            scaled by its square root). Defaults to the module-level
            ``hp.temperature`` when omitted.

    Returns:
        Tuple ``(x, y)``.
    """
    # inputs must be floats
    if greedy:
        return mu_x,mu_y
    if temperature is None:
        temperature = hp.temperature
    scale = np.sqrt(temperature)
    # Scale into new names rather than using `*=`, which would modify an
    # array-like argument in place and leak the change to the caller.
    scaled_x = sigma_x * scale
    scaled_y = sigma_y * scale
    covariance_xy = rho_xy * scaled_x * scaled_y
    cov = [[scaled_x * scaled_x, covariance_xy],
           [covariance_xy, scaled_y * scaled_y]]
    x = np.random.multivariate_normal([mu_x, mu_y], cov, 1)
    return x[0][0], x[0][1]

# loss function to compare produced_z with given z
def loss_function(predicted_z, target_z):
    """Return the mean squared error between the predicted and the target latent vector."""
    return nn.MSELoss()(predicted_z, target_z)

In [None]:
# Test model

def test(photos):
    """Plot each photo beside its target sketch and the model's predicted sketch.

    Args:
        photos: iterable of ``(photo, target_sequence)`` pairs.

    Each pair takes three consecutive subplot slots (photo, target,
    prediction). NOTE(review): ``encode`` and ``decode`` are not defined in
    the visible cells, and show_image lays plots out on a 3x3 grid, so only
    the first three pairs have a valid slot.
    """
    plt.figure(figsize=(10,10))
    for pair_number, (photo, target_sequence) in enumerate(photos):
        first_slot = 3 * pair_number
        show_image(photo, first_slot + 1, "Before sketching:")
        make_image(target_sequence, first_slot + 2, "Target sketch:")
        predicted_sequence = decode(encode(photo))
        make_image(predicted_sequence, first_slot + 3, "Predicted sketch:")
        
def show_image(image, idx, label):
    """Show ``image`` titled ``label`` in slot ``idx`` of a 3x3 subplot grid.

    The tensor is squeezed and its first axis is moved last before it is
    converted to a NumPy array for ``imshow``.
    """
    squeezed = image.squeeze()
    channels_last = squeezed.permute(1, 2, 0)
    pixels = channels_last.detach().numpy()
    plt.subplot(3, 3, idx)
    plt.title(label)
    plt.imshow(pixels)
    
def make_image(sequence, idx, label):
    """Draw a stroke sequence in slot ``idx`` of the 3x3 subplot grid.

    The previous body ignored all three arguments and referenced the
    undefined names ``i`` and ``mpimg``. This version plots the strokes the
    same way as the commented-out draft kept in this cell, and places them
    in the grid the same way show_image does.

    Args:
        sequence: array-like of shape (n_points, 3) holding x, y and a pen
            flag, where a flag > 0 ends the current stroke -- assumed from
            the draft implementation, TODO confirm (a tensor must be
            detached and on the CPU to be converted).
        idx: 1-based subplot slot.
        label: subplot title.
    """
    points = np.asarray(sequence)
    # split right after every point whose pen flag marks the end of a stroke
    stroke_ends = np.where(points[:, 2] > 0)[0] + 1
    strokes = np.split(points, stroke_ends)
    plt.subplot(3, 3, idx)
    plt.title(label)
    for stroke in strokes:
        # y is negated so the sketch is not drawn upside down on the plot axes
        plt.plot(stroke[:, 0], -stroke[:, 1])
    
# TODO: edit this function
# def make_image(sequence, epoch, name='_output_'):
#     """plot drawing with separated strokes"""
#     strokes = np.split(sequence, np.where(sequence[:,2]>0)[0]+1)
#     fig = plt.figure()
#     ax1 = fig.add_subplot(111)
#     for s in strokes:
#         plt.plot(s[:,0],-s[:,1])
#     print("Outputting sketch")
#     plt.show()