In [None]:
from google.colab import drive

import numpy as np
import matplotlib.pyplot as plt

import math
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Variable

# Mount Google Drive at /content/gdrive. data_path ('./gdrive/...') is relative, so it
# assumes the working directory is /content (Colab's default — confirm if running elsewhere).
drive.mount('/content/gdrive')

In [None]:
# Data location and batching.
data_path = './gdrive/MyDrive/Pytorch_SketchRNN/Data/'
batch_size = 100
max_seq_length = 200  # sketches longer than this are dropped when loading

# Model sizes.
input_size = 5  # vectorized data: (del x, del y, p1, p2, p3)
dim_z = 128
encoder_hidden_size = 256
decoder_hidden_size = 512
num_mix_components = 20  # Gaussian components in the decoder's output mixture

# Training / sampling hyper-parameters.
# NOTE: passed to a single-layer nn.LSTM in EncoderRNN; nn.LSTM only applies dropout
# between stacked layers, so with num_layers=1 it has no effect (PyTorch warns).
dropout = 0.9
eta_min = 0.01
temperature = 0.4
grad_clip = 1.

# Overwritten by StrokesDataset.__init__ with the longest sequence it kept.
longest_seq_len = 0
device = ("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook uses so Restart & Run All gives the same results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class StrokesDataset(Dataset):
  """In-memory stroke dataset converted to padded stroke-5 format.

  Each raw sequence is an (n, 3) array of (dx, dy, pen_flag) rows. Sequences whose
  length is not in (10, max_seq_length] are dropped, values are clamped to
  [-1000, 1000], and (dx, dy) are divided by `scale`.

  Attributes:
    scale: divisor applied to (dx, dy). If not given, it is the std of all kept
      offsets. Pass the training set's scale when building valid/test sets.
    longest_seq_len: length of the longest kept raw sequence.
    data: float tensor (num_seqs, longest_seq_len + 2, 5). Row 0 is the start
      token (0, 0, 1, 0, 0), rows 1..n hold the sketch, later rows are (0, 0, 0, 0, 1).
    mask: tensor (num_seqs, longest_seq_len + 1); mask[i, t] == 1 for t <= n, i.e.
      for targets data[i, 1:n + 2] (the sketch plus its first end row).

  Side effect: also assigns the module-level global `longest_seq_len`, which other
  code in this notebook reads.

  Raises:
    ValueError: if no sequence survives the length filter.
  """
  def __init__(self, dataset, max_seq_length, scale = None):
    global longest_seq_len

    data = []
    for seq in dataset:
      if 10 < len(seq) <= max_seq_length:
        # Clamp very large jumps to [-1000, 1000].
        seq = np.clip(seq, -1000, 1000)
        data.append(np.array(seq, dtype=np.float32))

    if not data:
      raise ValueError(
          f"no sequence has a length in (10, {max_seq_length}]; nothing to load")

    if scale is None:
      scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
    self.scale = scale

    self.longest_seq_len = max(len(seq) for seq in data)
    longest_seq_len = self.longest_seq_len
    self.data = torch.zeros(len(data), self.longest_seq_len + 2, 5, dtype=torch.float)
    self.mask = torch.zeros(len(data), self.longest_seq_len + 1)

    for i, seq in enumerate(data):
      seq = torch.from_numpy(seq)
      len_seq = len(seq)
      # Rows 1..len_seq hold the sketch; row 0 is the start token set below.
      self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
      self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]  # p1 = 1 - raw pen flag
      self.data[i, 1:len_seq + 1, 3] = seq[:, 2]      # p2 = raw pen flag
      self.data[i, len_seq + 1:, 4] = 1               # p3 marks end-of-sketch padding
      self.mask[i, :len_seq + 1] = 1

    self.data[:, 0, 2] = 1  # start-of-sequence token (0, 0, 1, 0, 0)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    """Return (data[idx], mask[idx])."""
    return self.data[idx], self.mask[idx]

In [None]:
def max_size(data):
    """Return the length of the longest sequence in `data` (ValueError if `data` is empty)."""
    return max(map(len, data))

In [None]:
def purify(strokes, max_len=None, min_len=10):
    """Keep sequences with min_len < length <= max_len, clamped and cast to float32.

    Large offsets are clamped to [-1000, 1000] (shrunk, not removed). This applies
    the same length filter and clamping as StrokesDataset.__init__.

    Args:
        strokes: iterable of (n, 3) numpy arrays.
        max_len: longest length to keep; defaults to the global `max_seq_length`.
        min_len: sequences must be strictly longer than this (default 10).

    Returns:
        list of new float32 arrays; the input arrays are not modified.
    """
    if max_len is None:
        max_len = max_seq_length
    return [
        np.clip(seq, -1000, 1000).astype(np.float32)
        for seq in strokes
        if min_len < seq.shape[0] <= max_len
    ]

In [None]:
def calculate_normalizing_scale_factor(strokes):
    """Std of every (dx, dy) value in `strokes` (the sketch-rnn normalizing factor).

    Args:
        strokes: sequence of (n, >=2) numpy arrays whose first two columns are offsets.

    Returns:
        numpy scalar; nan (with numpy's RuntimeWarning) when `strokes` is empty.
    """
    # A single vectorized concatenate replaces two Python appends per point.
    # Ravelling seq[:, 0:2] yields x0, y0, x1, y1, ...: the same values in the same order.
    offsets = [np.ravel(seq[:, 0:2]) for seq in strokes]
    if not offsets:
        return np.std(np.array([]))
    return np.std(np.concatenate(offsets))

In [None]:
def normalize(strokes, scale_factor=None):
    """Return copies of `strokes` with (dx, dy) divided by a scaling factor.

    Args:
        strokes: list of (n, 3) numpy arrays; they are not modified.
        scale_factor: divisor for the first two columns. Defaults to
            calculate_normalizing_scale_factor(strokes). Pass the training-set
            value to normalize validation/test data consistently.

    Returns:
        list of float arrays, one per input sequence.
    """
    if scale_factor is None:
        scale_factor = calculate_normalizing_scale_factor(strokes)
    data = []
    for seq in strokes:
        # Copy, promoting integer arrays to float (in-place `/=` cannot write floats
        # into them), so the caller's arrays stay intact and reruns don't re-scale.
        seq = np.array(seq, dtype=np.result_type(seq, np.float32))
        seq[:, 0:2] /= scale_factor
        data.append(seq)
    return data

In [None]:
# Load the 'cat' training split. allow_pickle=True unpickles the object arrays in the
# .npz, which can execute arbitrary code, so only load files you trust. The `with`
# block closes the NpzFile's file handle once 'train' has been read into memory.
with np.load(data_path + 'cat.npz', encoding='latin1', allow_pickle=True) as dataset:
    data = dataset['train']
data = purify(data)
data = normalize(data)
Nmax = max_size(data)

In [None]:
def make_batch(batch_size):
    """Sample `batch_size` sketches from the global `data` as a padded stroke-5 batch.

    Reads the module-level `data` (list of (n, 3) arrays), `Nmax` and `device`.

    Returns:
        batch: float tensor (Nmax, batch_size, 5) on `device` (sequence-major).
        lengths: list of the raw lengths of the sampled sequences.
    """
    # Honour batch_size: draw that many sketches (with replacement).
    batch_idx = np.random.choice(len(data), batch_size)
    strokes = []
    lengths = []
    for idx in batch_idx:
        seq = data[idx]
        len_seq = len(seq)
        new_seq = np.zeros((Nmax, 5))
        new_seq[:len_seq, :2] = seq[:, :2]
        new_seq[:len_seq - 1, 2] = 1 - seq[:-1, 2]  # p1 = 1 - raw pen flag
        new_seq[:len_seq, 3] = seq[:, 2]            # p2 = raw pen flag
        new_seq[len_seq - 1:, 4] = 1                # p3 from the last point onward
        new_seq[len_seq - 1, 2:4] = 0               # the last point carries only p3
        lengths.append(len_seq)
        strokes.append(new_seq)
    batch = torch.from_numpy(np.stack(strokes, 1)).to(device).float()
    return batch, lengths

In [None]:
class BivariateGaussianMixture:
  """Mixture of bivariate Gaussians over (dx, dy), built from decoder outputs.

  All parameters are tensors of one shape (presumably (seq_len, batch,
  num_mix_components) — TODO confirm for other callers). Despite its name,
  `pi_cat_probs` is treated as *logits*: get_distribution passes it to
  Categorical(logits=...) and set_temperature divides it by the temperature.
  """
  def __init__(self, pi_cat_probs, mu_x, mu_y, sig_x, sig_y, rho_xy):
    self.pi_cat_probs = pi_cat_probs  # mixture logits (see class docstring)
    self.mu_x = mu_x
    self.mu_y = mu_y
    self.sig_x = sig_x
    self.sig_y = sig_y
    self.rho_xy = rho_xy  # correlation coefficient

  def set_temperature(self, temperature):
    """Rescale for sampling: logits / temperature, std devs * sqrt(temperature).

    New tensors are assigned instead of modifying in place. Tensors passed to
    __init__ (e.g. `exp` outputs that autograd still needs) are left untouched.
    """
    self.pi_cat_probs = self.pi_cat_probs / temperature
    std_scale = math.sqrt(temperature)
    self.sig_x = self.sig_x * std_scale
    self.sig_y = self.sig_y * std_scale

  def get_distribution(self):
    """Return (MultivariateNormal over (x, y) per component, Categorical over components)."""
    # Clamp to keep the covariance positive definite.
    sig_x = torch.clamp_min(self.sig_x, 1e-5)
    sig_y = torch.clamp_min(self.sig_y, 1e-5)
    rho_xy = torch.clamp(self.rho_xy, 1e-5 - 1, 1 - 1e-5)

    mean = torch.stack([self.mu_x, self.mu_y], -1)

    cov = torch.stack([
            sig_x * sig_x, rho_xy * sig_x * sig_y,
            rho_xy * sig_x * sig_y, sig_y * sig_y
          ], -1)
    cov = cov.view(*sig_y.shape, 2, 2)

    bi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
    cat_dist = torch.distributions.Categorical(logits=self.pi_cat_probs)

    return bi_dist, cat_dist

  def bivariate_normal_pdf(self, dx, dy):
    """Per-component bivariate normal density at (dx, dy) (unclamped parameters)."""
    z_x = ((dx-self.mu_x)/self.sig_x)**2
    z_y = ((dy-self.mu_y)/self.sig_y)**2
    z_xy = (dx-self.mu_x)*(dy-self.mu_y)/(self.sig_x*self.sig_y)
    z = z_x + z_y -2*self.rho_xy*z_xy
    exp = torch.exp(-z/(2*(1-self.rho_xy**2)))
    norm = 2*np.pi*self.sig_x*self.sig_y*torch.sqrt(1-self.rho_xy**2)
    return exp/norm

In [None]:
class EncoderRNN(nn.Module):
  """Bidirectional LSTM encoder mapping a stroke-5 sequence to a latent z.

  forward returns z = mean + exp(sigma / 2) * eps with eps ~ N(0, I), where
  `sigma` is the predicted log-variance (the VAE reparameterisation trick).
  """
  def __init__(self):
    super(EncoderRNN, self).__init__()

    # NOTE: nn.LSTM applies `dropout` only between stacked layers; with the default
    # num_layers=1 this argument has no effect (PyTorch emits a warning).
    self.rnn_lstm = nn.LSTM(input_size, encoder_hidden_size, dropout=dropout, bidirectional=True)
    self.mean_linear = nn.Linear(2 * encoder_hidden_size, dim_z)
    self.sigma_linear = nn.Linear(2 * encoder_hidden_size, dim_z)

  def forward(self, input, batch_size, state=None):
    """Encode `input`, presumably (seq_len, batch_size, 5) sequence-major.

    Args:
      input: stroke tensor; cast to float before the LSTM.
      batch_size: number of sequences; sizes the zero initial state.
      state: optional (h0, c0), each (2, batch_size, encoder_hidden_size).

    Returns:
      (z, mean, sigma), each (batch_size, dim_z); sigma is the log-variance.
    """
    if state is None:
      # Build the zero state on the input's device so CPU and GPU callers both work.
      zeros = torch.zeros(2, batch_size, encoder_hidden_size, device=input.device)
      state = (zeros, zeros.clone())
    _, (hn, cn) = self.rnn_lstm(input.float(), state)

    # hn is (2, batch, hidden): final forward and backward states, placed side by side.
    hidden_cat = torch.cat([hn[0], hn[1]], 1)
    mean = self.mean_linear(hidden_cat)
    sigma = self.sigma_linear(hidden_cat)
    std = torch.exp(sigma / 2.)
    # Reparameterisation: scale unit Gaussian noise by the standard deviation.
    z = mean + std * torch.randn_like(std)

    return z, mean, sigma

In [None]:
class DecoderRNN(nn.Module):
    """LSTM decoder producing mixture-density parameters for the next stroke.

    For each step, the fully-connected head emits 6 values per mixture component
    (mixture logit, mu_x, mu_y, log sig_x, log sig_y, pre-tanh rho_xy), grouped
    per component, followed by 3 pen-state logits.
    """
    def __init__(self):
      super(DecoderRNN, self).__init__()

      self.init_state = nn.Linear(dim_z, 2 * decoder_hidden_size)
      self.rnn_lstm = nn.LSTM(dim_z + 5, decoder_hidden_size)

      self.fc_layer = nn.Linear(decoder_hidden_size, 6 * num_mix_components + 3) # num_mix*(weight, mu_x, mu_y, sig_x, sig_y, ro_xy) + 3 states

    def forward(self, input, z, state=None):
      """Run the decoder.

      Args:
        input: (seq_len, batch, dim_z + 5) tensor.
        z: (batch, dim_z) latent. It is only used to build the initial state
          when `state` is None.
        state: optional LSTM (h, c), each (1, batch, decoder_hidden_size).

      Returns:
        dist: BivariateGaussianMixture with (out_len, batch, num_mix_components)
          tensors; its `pi_cat_probs` holds mixture *logits*.
        q_cat_probs: (out_len, batch, 3) pen-state *log*-probabilities.
        h_0, c_0: final LSTM state, to pass back in as `state`.
        In training mode every step is decoded (out_len == seq_len). In eval mode
        only the final step is decoded (out_len == 1).
      """
      if state is None:
        h_0, c_0 = torch.split(torch.tanh(self.init_state(z)), decoder_hidden_size, 1)
        state = (h_0.unsqueeze(0).contiguous(), c_0.unsqueeze(0).contiguous())

      outputs, (h_0, c_0) = self.rnn_lstm(input, state)

      if self.training:
        # Take the length from the input itself, not from the global longest_seq_len,
        # which is overwritten by whichever StrokesDataset was built last.
        out_len = outputs.shape[0]
        params = self.fc_layer(outputs.reshape(-1, decoder_hidden_size))
      else:
        out_len = 1
        params = self.fc_layer(h_0.reshape(-1, decoder_hidden_size))

      mixture, q_logits = torch.split(params, [6 * num_mix_components, 3], 1)
      # (N, num_mix, 6): component k's parameters are columns 6k .. 6k+5 of `params`.
      pi_logits, mu_x, mu_y, sig_x, sig_y, rho_xy = (
          mixture.reshape(-1, num_mix_components, 6).unbind(-1))

      shape = (out_len, -1, num_mix_components)
      # Mixture weights stay as logits. BivariateGaussianMixture passes them to
      # Categorical(logits=...) and divides them by the sampling temperature, so a
      # softmax here would normalise twice and break temperature scaling.
      pi_logits = pi_logits.reshape(shape)
      mu_x = mu_x.reshape(shape)
      mu_y = mu_y.reshape(shape)
      sig_x = torch.exp(sig_x).reshape(shape)
      sig_y = torch.exp(sig_y).reshape(shape)
      rho_xy = torch.tanh(rho_xy).reshape(shape)
      # Pen states as log-probabilities. ReconstructionLoss multiplies them by the
      # one-hot pen targets (a cross-entropy only for log-probs), and the Sampler
      # uses them as Categorical logits.
      q_cat_probs = F.log_softmax(q_logits, dim=-1).reshape(out_len, -1, 3)

      dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y, sig_x, sig_y, rho_xy)

      return dist, q_cat_probs, h_0, c_0

In [None]:
class ReconstructionLoss(nn.Module):
  """Negative log-likelihood of the target strokes under the decoder's predictions."""
  def forward(self, target, mask, dist, q_logits):
    """Return the masked offset NLL plus the pen-state term.

    Args:
      target: (seq_len, batch, 5) stroke-5 targets.
      mask: (seq_len, batch); weights the offset term per step.
      dist: object whose get_distribution() returns (MultivariateNormal,
        Categorical) over the mixture components (e.g. a BivariateGaussianMixture).
      q_logits: (seq_len, batch, 3) pen-state scores, multiplied directly by the
        one-hot pen columns of `target` (a cross-entropy only if they are log-probs).

    Returns:
      Scalar tensor; both terms are plain means over all their elements.
    """
    gaussians, weights = dist.get_distribution()
    num_components = weights.probs.shape[-1]
    # Repeat each (dx, dy) target once per mixture component.
    offsets = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, num_components, -1)
    component_density = torch.exp(gaussians.log_prob(offsets))
    mixture_density = torch.sum(weights.probs * component_density, 2)
    # The 1e-5 keeps log() finite where the density underflows to zero.
    stroke_nll = -torch.mean(mask * torch.log(1e-5 + mixture_density))

    pen_nll = -torch.mean(target[:, :, 2:] * q_logits)
    return stroke_nll + pen_nll

In [None]:
class KLDivLoss(nn.Module):
  """KL divergence of N(mean, exp(sigma)) from N(0, I), averaged over all elements.

  `sigma` is the log-variance. Note the argument order: forward(sigma, mean).
  """
  def forward(self, sigma, mean):
    variance = torch.exp(sigma)
    per_element = 1 + sigma - mean ** 2 - variance
    return torch.mean(per_element) * -0.5

In [None]:
class Sampler:
  """Draws sketches from a SketchRNN_Model and plots them with matplotlib."""

  def __init__(self, model):
    self.model = model

  def sample_conditional(self, data, temperature):
    """Encode `data` (presumably (seq_len, 1, 5)) and decode a sketch from its z.

    Returns the sampled stroke-5 sequence (see `sample`).
    """
    self.model.eval().to(device)
    z, _, _ = self.model.encoder(data, 1)
    return self.sample(z, temperature)

  def sample_unconditional(self, temperature):
    """Decode a sketch from z ~ N(0, I); returns the sampled sequence."""
    self.model.eval().to(device)
    # Create z on the model's device (torch.randn would otherwise allocate on the CPU).
    z = torch.randn(1, dim_z, dtype=torch.float, device=device)
    return self.sample(z, temperature)

  def sample(self, z, temperature, max_len=None):
    """Autoregressively decode one sketch from latent `z`, plot it and return it.

    Args:
      z: (1, dim_z) latent vector.
      temperature: sampling temperature forwarded to `_sample_step`.
      max_len: maximum number of generated strokes; defaults to the global
        `max_seq_length`.

    Returns:
      (n + 1, 5) stroke-5 tensor: the start token followed by n sampled strokes.
      Sampling stops after the first stroke whose p3 (end) flag is set.
    """
    if max_len is None:
      max_len = max_seq_length
    # Start token (0, 0, 1, 0, 0) with z's dtype and device.
    s = z.new_tensor([0, 0, 1, 0, 0])
    seq = [s]
    state = None  # the first decoder call derives (h, c) from z

    with torch.no_grad():
      for _ in range(max_len):
        data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
        dist, q_cat_probs, h, c = self.model.decoder(data, z, state)
        # Feed the LSTM state back in. Without it, every step would restart from the
        # z-derived initial state and see only the previous stroke.
        state = (h, c)
        s = self._sample_step(dist, q_cat_probs, temperature)
        seq.append(s)
        if s[4] == 1:
          break

    seq = torch.stack(seq)
    self.plot(seq)
    return seq

  @staticmethod
  def _sample_step(dist, q_cat_probs, temperature):
    """Sample one stroke-5 vector from a single decoder step.

    Indexes [0, 0], so it assumes time and batch dimensions of size 1. It picks a
    mixture component, draws (dx, dy) from that component, and draws the pen
    state using `q_cat_probs / temperature` as Categorical logits.
    """
    dist.set_temperature(temperature)
    bi_dist, cat_dist = dist.get_distribution()
    idx = cat_dist.sample()[0, 0]
    q = torch.distributions.Categorical(logits=q_cat_probs / temperature)
    q_idx = q.sample()[0, 0]
    xy = bi_dist.sample()[0, 0, idx]
    stroke = q_cat_probs.new_zeros(5)
    stroke[:2] = xy
    stroke[q_idx + 2] = 1
    return stroke

  @staticmethod
  def plot(seq: torch.Tensor):
    """Draw a stroke-5 sequence; works on a copy so `seq` is not modified."""
    seq = seq.detach().clone()
    # Offsets -> absolute positions.
    seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
    seq[:, 2] = seq[:, 3]
    seq = seq[:, 0:3].cpu().numpy()
    # Start a new polyline after every point whose p2 flag is set.
    strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
    for s in strokes:
      plt.plot(s[:, 0], -s[:, 1])  # y negated (presumably y points down in the data)
    plt.axis('off')
    plt.show()

In [None]:
class SketchRNN_Model(nn.Module):
  """Sketch-RNN VAE: EncoderRNN -> latent z -> DecoderRNN, plus its two losses."""
  def __init__(self):
    super().__init__()
    self.encoder = EncoderRNN().to(device)
    self.decoder = DecoderRNN().to(device)

    self.loss_kl = KLDivLoss().to(device)
    self.loss_rec = ReconstructionLoss().to(device)

  def forward(self, inputs):
    """Encode `inputs` and decode them with teacher forcing.

    Args:
      inputs: (seq_len, batch, 5) stroke-5 tensor, sequence-major (as built in `step`).

    Returns:
      (dist, q_cat_probs, mean, sigma): the decoder's predictions for inputs[1:],
      plus the encoder's latent mean and log-variance.
    """
    # Use the real batch size. DataLoader's last batch can be smaller than the global
    # batch_size, which would not match the encoder's initial state.
    z, mean, sigma = self.encoder(inputs, inputs.shape[1])

    # Teacher forcing: feed inputs[:-1] and condition every step on z. The length
    # comes from the input, not from the global longest_seq_len, which is
    # overwritten by whichever StrokesDataset was built last.
    decoder_inputs = inputs[:-1]
    z_stack = z.unsqueeze(0).expand(decoder_inputs.shape[0], -1, -1)
    decoder_inputs = torch.cat([decoder_inputs, z_stack], 2)

    dist, q_cat_probs, _, _ = self.decoder(decoder_inputs, z, None)

    return dist, q_cat_probs, mean, sigma

In [None]:
def step(model, batch):
  """Compute the training loss (KL + reconstruction) for one DataLoader batch.

  Args:
    model: SketchRNN_Model; switched to training mode here.
    batch: (data, mask) collated from StrokesDataset: data (batch, L + 2, 5) and
      mask (batch, L + 1).

  Returns:
    Scalar loss tensor (not yet back-propagated).
  """
  model.train()

  # The model works sequence-major: (seq, batch, ...).
  data = batch[0].to(device).transpose(0, 1)
  mask = batch[1].to(device).transpose(0, 1)

  dist, q_cat_probs, mean, sigma = model(data)

  # KLDivLoss.forward takes (sigma, mean); pass by keyword so the order cannot be swapped.
  loss_kl = model.loss_kl(sigma=sigma, mean=mean)
  # Targets are every row after the start token.
  loss_draw = model.loss_rec(data[1:], mask, dist, q_cat_probs)
  loss = loss_kl + loss_draw

  return loss

In [None]:
def train_epoch(model, epoch, batch, valid_dataset, optimizer, sampler):
  """Run one optimisation step on `batch` and print the loss.

  Despite the name, this handles a single batch; `epoch` is the caller's
  iteration counter. Every 100 iterations it also draws a conditional sample
  from a random validation sketch with `sampler`.
  """
  # Clear the previous step's gradients; otherwise backward() accumulates them.
  optimizer.zero_grad()
  loss = step(model, batch)
  loss.backward()
  nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
  optimizer.step()

  print('epoch', epoch, 'loss', loss.item())
  if epoch % 100 == 0:
    data, *_ = valid_dataset[np.random.choice(len(valid_dataset))]
    data = data.unsqueeze(1).to(device)  # add a batch dimension of size 1
    sampler.sample_conditional(data, 1)

In [None]:
# Load the train/valid splits. allow_pickle=True unpickles object arrays, which can run
# arbitrary code, so only load trusted files. The `with` block closes the .npz handle.
with np.load(data_path + 'cat.npz', encoding='latin1', allow_pickle=True) as dataset:
  train_data = StrokesDataset(dataset['train'], max_seq_length)
  # Normalise validation offsets with the *training* scale so both splits share units.
  valid_data = StrokesDataset(dataset['valid'], max_seq_length, train_data.scale)

# Shuffle so training batches are not always drawn in file order.
train_loader = DataLoader(train_data, batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size)

In [None]:
# Model, an Adam optimiser with default hyper-parameters, and a sampler for previews.
model = SketchRNN_Model().to(device)
optimizer = optim.Adam(params=model.parameters())
sampler = Sampler(model=model)

In [None]:
# One pass over train_loader. Each batch is one optimisation step; train_epoch's
# `epoch` argument is really the batch index.
for epoch, batch in enumerate(train_loader):
  train_epoch(model, epoch, batch,
              valid_dataset=valid_data, optimizer=optimizer, sampler=sampler)

In [None]:
# Debug: print the first batch the loader yields, then stop.
for epoch, batch in enumerate(train_loader):
  print(batch)
  break

In [None]:
def plot(seq: torch.Tensor):
  """Draw a stroke-5 sequence with matplotlib (same drawing logic as Sampler.plot).

  Offsets are cumulated into absolute positions, and a new polyline starts after
  every point whose p2 (column 3) flag is set. Works on a copy, so the caller's
  `seq` is not modified.
  """
  seq = seq.detach().clone()
  seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
  seq[:, 2] = seq[:, 3]
  seq = seq[:, 0:3].cpu().numpy()
  strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
  for s in strokes:
    plt.plot(s[:, 0], -s[:, 1])  # y negated (presumably y points down in the data)
  plt.axis('off')
  plt.show()

In [None]:
# Display the first padded stroke-5 sequence in `batch` (left over from the loop
# above); batch[0] is the data tensor, and row 0 is the start token (0, 0, 1, 0, 0).
batch[0][0]