In [1]:
import numpy as np
from utils import *

In [2]:
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from dataset import *
from torch.utils.data import DataLoader
import torch.distributions as dist
from torch import optim

In [3]:
# Load the stroke archive and prepare it for the model.
# `encoding='latin1'` is only consulted when unpickling, so the archive holds
# pickled object arrays; NumPy >= 1.16.3 refuses to load those unless
# allow_pickle=True is passed explicitly.
# SECURITY: unpickling executes arbitrary code -- only load archives from a
# trusted source.
dataset = np.load("sketchrnn-cat.full.npz", encoding='latin1', allow_pickle=True)
# NOTE(review): this feeds the 'test' split into the training loop below --
# confirm that is intentional (e.g. a quick smoke run) rather than a typo for
# a training split.
data = dataset['test']
# purify / normalize / max_size come from the star-import of `utils`; their
# exact behaviour is not visible in this notebook.
data = purify(data)
data = normalize(data)
Nmax = max_size(data)

In [4]:
class Encoder(nn.Module):
    """Bidirectional GRU that maps a stroke sequence to posterior parameters.

    Args:
        hidden_size: hidden units per GRU direction.
        z_size: dimensionality of the latent code.
        dropout: value forwarded to ``nn.GRU(dropout=...)``.

    ``forward`` takes a batch-first tensor whose last dimension is 5 and
    returns ``(mu, logvar)``, each with ``z_size`` features per batch row.
    """

    def __init__(self,
                 hidden_size=256,
                 z_size=128,
                 dropout=0.9):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.z_size = z_size
        # Forward and backward final states are concatenated, hence 2x.
        feature_size = 2 * self.hidden_size
        self.rnn = nn.GRU(input_size=5,
                          hidden_size=self.hidden_size,
                          batch_first=True,
                          dropout=dropout,
                          bidirectional=True)
        self.fc_mu = nn.Linear(feature_size, self.z_size)
        self.fc_logvar = nn.Linear(feature_size, self.z_size)

    def forward(self, inputs):
        """Encode ``inputs`` and return the pair ``(mu, logvar)``."""
        _, last_hidden = self.rnn(inputs)
        # last_hidden is (2, batch, hidden): one slice per direction.
        per_direction = last_hidden.split(1, 0)
        # Join the directions on the feature axis, then drop the leading
        # singleton axis to get (batch, 2 * hidden).
        features = torch.cat(per_direction, 2).squeeze(0)
        return self.fc_mu(features), self.fc_logvar(features)

In [5]:
class Decoder(nn.Module):
    """GRU decoder conditioned on a latent code ``z``.

    ``z`` is used twice: projected through ``fc`` + tanh to form the initial
    hidden state, and concatenated onto every time step of the input.

    Args:
        hidden_size: number of GRU hidden units.
        z_size: dimensionality of the latent code.
        dropout: value forwarded to ``nn.GRU(dropout=...)``.
    """

    def __init__(self,
                 hidden_size=256,
                 z_size=128,
                 dropout=0.9):
        super(Decoder, self).__init__()
        # Each step sees the 5 stroke features plus the latent code.
        self.rnn = nn.GRU(input_size=z_size + 5,
                          hidden_size=hidden_size,
                          batch_first=True,
                          dropout=dropout)
        self.fc = nn.Linear(z_size, hidden_size)

    def init_state(self, z):
        """Return the initial GRU hidden state, shaped (1, batch, hidden)."""
        projected = self.fc(z)
        return torch.tanh(projected.unsqueeze(0))

    def forward(self, inputs, z):
        """Run the GRU over ``inputs`` with ``z`` appended to every step.

        Returns whatever ``nn.GRU`` returns: ``(outputs, last_hidden)``.
        """
        steps = inputs.size(1)
        # Broadcast z along the time axis without copying.
        z_per_step = z.unsqueeze(1).expand(-1, steps, -1)
        rnn_inputs = torch.cat([inputs, z_per_step], dim=2)
        return self.rnn(rnn_inputs, self.init_state(z))

In [6]:
class SketchRNN(nn.Module):
    """Sequence VAE over 5-feature stroke steps with a mixture-density output.

    An ``Encoder`` produces ``(mu, logvar)``, a latent ``z`` is drawn from
    them, and a ``Decoder`` conditioned on ``z`` predicts, for every step, the
    parameters of an ``M``-component bivariate Gaussian mixture over the first
    two features plus 3 pen-state logits.

    Args:
        enc_hidden_size: hidden units per encoder GRU direction.
        dec_hidden_size: hidden units of the decoder GRU.
        z_size: dimensionality of the latent code.
        dropout: value forwarded to both GRUs.
        M: number of mixture components.

    Note: ``forward`` does not return ``mu``/``logvar``, so ``KLD`` can only
    be evaluated on values obtained from a separate ``self.encoder`` call.
    """

    def __init__(self,
                 enc_hidden_size=256,
                 dec_hidden_size=256,
                 z_size=128,
                 dropout=0.9,
                 M=20):
        super(SketchRNN, self).__init__()
        self.dec_hidden_size = dec_hidden_size
        self.M = M
        self.encoder = Encoder(enc_hidden_size,z_size,dropout)
        self.decoder = Decoder(dec_hidden_size,z_size,dropout)
        # 6 groups of M values (pi, mu1, mu2, sigma1, sigma2, corr) + 3 pen logits.
        self.gmm = nn.Linear(dec_hidden_size,6*M+3)

    def forward(self,inputs):
        """Encode ``inputs``, sample ``z`` and decode with teacher forcing.

        Args:
            inputs: batch-first tensor ``(batch, seq, 5)``.

        Returns:
            ``z`` followed by the 8 values of ``get_mixture_coef``. The
            mixture values are flattened over batch and time, i.e. they have
            ``batch * (seq - 1)`` rows.
        """
        # The encoder skips step 0; the decoder is fed steps 0..seq-2, so its
        # outputs line up with targets ``inputs[:, 1:, :]``.
        enc_inputs = inputs[:,1:,:]
        mu,logvar = self.encoder(enc_inputs)
        z = self.reparameterize(mu,logvar)

        dec_inputs = inputs[:,:-1,:]
        outputs,_ = self.decoder(dec_inputs,z)
        outputs = self.gmm(outputs.contiguous().view(-1,self.dec_hidden_size))
        o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits = self.get_mixture_coef(outputs)

        return z,o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` in training mode and ``mu`` otherwise."""
        if self.training:
            std = logvar.mul(0.5).exp_()
            # Noise is created with the same type/device as ``std``.
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu

    def get_mixture_coef(self,outputs):
        """Split raw ``gmm`` outputs into normalised mixture parameters.

        Args:
            outputs: tensor ``(rows, 6*M + 3)``; columns 0..2 are pen logits,
                the rest are six consecutive groups of ``M`` columns.

        Returns:
            ``[pi, mu1, mu2, sigma1, sigma2, corr, pen, pen_logits]`` where
            ``pi`` and ``pen`` are softmaxed, the sigmas are exponentiated and
            ``corr`` is squashed into (-1, 1).
        """
        z_pen_logits = outputs[:, 0:3]  # pen states
        z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = torch.split(outputs[:, 3:], self.M, 1)
        # process output z's into MDN parameters

        # softmax all the pi's and pen states:
        z_pi = F.softmax(z_pi,1)
        z_pen = F.softmax(z_pen_logits,1)

        # exponentiate the sigmas and also make corr between -1 and 1.
        z_sigma1 = torch.exp(z_sigma1)
        z_sigma2 = torch.exp(z_sigma2)
        z_corr = torch.tanh(z_corr)

        r = [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, z_pen_logits]
        return r

    def get_lossfunc(self,z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                     z_pen_logits, x1_data, x2_data, pen_data):
        """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850.

        Args:
            z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr: mixture
                parameters, each ``(rows, M)``.
            z_pen_logits: pen-state logits ``(rows, 3)``.
            x1_data, x2_data: target offsets, each ``(rows, 1)``.
            pen_data: target pen flags ``(rows, 3)``; the arg-max column is
                the class label and column 2 masks the offset loss.

        Returns:
            Scalar mean over rows of (masked offset loss + pen loss).
        """
        # This represents the L_R only (i.e. does not include the KL loss term).

        def normal_2d(x1, x2, mu1, mu2, s1, s2, rho):
            """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850."""
            norm1 = x1 - mu1
            norm2 = x2 - mu2
            s1s2 = s1*s2
            # eq 25
            z = norm1.div(s1).pow(2) + norm2.div(s2).pow(2) - 2*(rho*norm1*norm2).div(s1s2)
            neg_rho = 1 - rho.pow(2)
            result = (-z).div(2*neg_rho).exp()
            denom = 2 * np.pi * s1s2 * torch.sqrt(neg_rho)
            result = result.div(denom)
            return result

        result0 = normal_2d(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2,z_corr)
        epsilon = 1e-6
        # result1 is the loss wrt pen offset (L_s in equation 9 of
        # https://arxiv.org/pdf/1704.03477.pdf)
        result1 = result0*z_pi
        result1 = torch.sum(result1, 1, keepdim=True)
        result1 = -torch.log(result1 + epsilon)  # avoid log(0)

        fs = 1.0 - pen_data[:, 2]  # use training data for this
        fs = fs.view(-1, 1)
        # Zero out loss terms beyond N_s, the last actual stroke
        result1 = result1*fs

        _,labels = pen_data.max(1)
        result2 = F.cross_entropy(z_pen_logits,labels,reduce=False)
        # Bring the per-row pen loss to (rows, 1) BEFORE masking. Multiplying
        # the 1-D (rows,) tensor by the (rows, 1) mask would broadcast to
        # (rows, rows) and mask every row's loss by every other row's flag.
        result2 = result2.unsqueeze(1)
        if not self.training:
            result2 = result2 * fs

        result = result1 + result2
        return result.mean()

    def KLD(self,mu,logvar):
        """KL divergence of N(mu, exp(logvar)) from N(0, I), summed over all elements."""
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    def generate(self,seq_len,z):
        """Sample ``seq_len`` steps from the decoder, one step at a time.

        Args:
            seq_len: number of steps to sample.
            z: latent code with a batch dimension of 1 (it is concatenated
                with a single-row input at every step).

        Returns:
            Tensor ``(1, seq_len + 1, 5)``: the start token followed by the
            sampled steps. Sampling does not stop early on any pen state.
        """
        # Start token: zero offset with the first pen flag set.
        sample = Variable(torch.Tensor([0,0,1,0,0]),requires_grad=False).view(1,1,5)
        hidden = self.decoder.init_state(z)

        def bivariate_normal(m1,m2,s1,s2,rho,temp=1.0):
            """Draw one (x, y) pair from a bivariate normal via NumPy."""
            mean = [m1, m2]
            # Scale out of place: an in-place `*=` on a sliced tensor may
            # write through to the mixture output it was taken from.
            s1 = s1 * (temp * temp)
            s2 = s2 * (temp * temp)
            cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]
            x = np.random.multivariate_normal(mean, cov, 1,check_valid='ignore')
            return x[0][0], x[0][1]

        for _ in range(seq_len):
            # Feed back only the most recently generated step.
            inputs = sample[:,-1:,:]
            inputs_z = torch.cat([inputs,z.unsqueeze(1)],dim=2)
            outputs,hidden = self.decoder.rnn(inputs_z,hidden)
            outputs = self.gmm(outputs.contiguous().view(-1,self.dec_hidden_size))
            o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits = self.get_mixture_coef(outputs)

            # NOTE(review): the `.data[0]` lookups below assume sample()
            # yields a 1-element tensor, as the PyTorch release this notebook
            # was written against presumably does -- confirm before upgrading.
            state = Variable(torch.zeros(1,1,5),requires_grad=False)
            pen_status = dist.Categorical(o_pen.view(-1)).sample()
            # Pen flags live in columns 2..4 of a step.
            state[:,:,pen_status.data[0]+2] = 1

            # Pick one mixture component, then sample the offset from it.
            M_id = dist.Categorical(o_pi.view(-1)).sample().data[0]
            mu_1 = o_mu1[0,M_id].data
            mu_2 = o_mu2[0,M_id].data
            sigma_1 = o_sigma1[0,M_id].data
            sigma_2 = o_sigma2[0,M_id].data
            rho = o_corr[0,M_id].data

            x,y = bivariate_normal(mu_1,mu_2,sigma_1,sigma_2,rho)
            state[:,:,0] = x
            state[:,:,1] = y

            sample = torch.cat([sample,state],dim=1)
        return sample

In [7]:
# Wrap the prepared strokes in the project's dataset class and batch them;
# `collate_fn_` (from the `dataset` star-import) assembles each batch.
sketch = Sketch(data)
training_data = DataLoader(
    dataset=sketch,
    batch_size=100,
    shuffle=True,
    collate_fn=collate_fn_,
)

In [8]:
# Model with its default hyper-parameters, optimised with Adam.
model = SketchRNN()
train_step = optim.Adam(params=model.parameters(), lr=1e-3)

In [9]:
# Train on the reconstruction loss only: the KL term (model.KLD) is not part
# of the objective here.
# The loop variable is named `batch` so it does not overwrite the notebook
# level `data` that `Sketch(data)` was built from.
for epoch in range(10):
    for batch in training_data:
        batch = Variable(batch)
        z, o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits = model(batch)

        # Targets are the inputs shifted by one step, flattened to
        # (batch * steps, 5) to match the model's flattened outputs.
        target = batch[:,1:,:]
        target = target.contiguous().view(-1,5)
        x1_data, x2_data, eos_data, eoc_data, cont_data = target.split(1,1)
        pen_data = torch.cat([eos_data, eoc_data, cont_data], 1)
        rec_loss = model.get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2,
                                      o_corr,o_pen_logits, x1_data, x2_data, pen_data)
        train_step.zero_grad()
        rec_loss.backward()
        train_step.step()
        # Function-call form prints identically under Python 2 and 3.
        # NOTE(review): `.data[0]` assumes the loss is a 1-element tensor, as
        # in the PyTorch release that produced the recorded output -- confirm
        # before upgrading.
        print(rec_loss.data[0])

2.57416605949
2.41275024414
2.38926815987
2.32114386559
2.33202314377
2.28989601135
2.16976857185
2.0873708725
2.07914400101
2.11441731453
1.90477955341
1.89647293091
1.87478649616
1.84154009819
1.70401978493
1.65748381615


KeyboardInterrupt: 

In [None]:
# Sample a 100-step sketch from the first latent code of the last training batch.
s = model.generate(seq_len=100, z=z[:1])

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Take the first item produced by iterating over the dataset object.
d1 = next(iter(sketch))

In [None]:
# Pass that single item through the same collate function the DataLoader uses
# (as a batch of one) and keep element 0 of the result.
d2 = collate_fn_([d1])[0]

In [None]:
# Columns 0 and 1 are accumulated into running coordinates for plotting.
x, y = (np.cumsum(d1.numpy()[:, col]) for col in (0, 1))

In [None]:
# Draw the segment from point i to point i+1 wherever column 2 is falsy.
# NOTE(review): presumably column 2 of the raw item marks "pen lifted" --
# verify against the dataset module.
for i in range(1, d1.size(0)):
    if d1.numpy()[i, 2]:
        continue
    plt.plot(x[i:i + 2], -y[i:i + 2])

In [None]:
# Same plot for the collated sample: accumulate columns 0 and 1, then draw the
# segment from point i to point i+1 wherever column 2 is truthy.
# NOTE(review): presumably column 2 of the collated format marks "pen down",
# the opposite sense of the raw item -- verify against collate_fn_.
x, y = (np.cumsum(d2.numpy()[:, col]) for col in (0, 1))

for i in range(1, d2.size(0)):
    if not d2.numpy()[i, 2]:
        continue
    plt.plot(x[i:i + 2], -y[i:i + 2])

In [None]:
# NOTE(review): looks like a leftover scratch cell. No `d` is assigned anywhere
# in the visible notebook, so unless one of the star-imports provides it this
# raises NameError on Restart & Run All -- consider deleting the cell.
d