In [24]:
from google.colab import drive

import functools
import os

import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML

tf.enable_eager_execution()

drive.mount('/content/drive', force_remount=False)
%cd /content/drive/My\ Drive/ETH/Projects/Google/source
!ls
from constants import Constants as C
from dataset_strokes import TFRecordStroke, TFRecordSingleDiagram
from ink_models import InkRNNSeq2Seq

TF_DATA_PATH = "/content/drive/My Drive/ETH/Projects/Google/data/diagrams_with_strokes/training/diagrams_ramer_0.025.tf_record-?????-of-?????"
MODEL_PATH = "/content/drive/My Drive/ETH/Projects/Google/trained_models/1554460043"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/ETH/Projects/Google/source
config.py	    ink_models.py	 None0000000.png     utils.py
constants.py	    model_components.py  __pycache__
dataset.py	    model.py		 training_eager.py
dataset_strokes.py  model_utils.py	 training_static.py


In [0]:
def tf_ink_batch_to_strokes(undo_fn, tf_ink_batch, tf_ink_start, tf_seq_len):
  """
  Converts a batch of strokes (tf tensor) to list of strokes after reverting 
  preprocessing steps applied. The return value can be passed to visualization
  functions directly.
  """
  ink_batch = undo_fn(tf_ink_batch.numpy(), tf_ink_start.numpy())
  seq_len = tf_seq_len.numpy()
  # y dimension is mirrored.
  ink_batch[:, :, 1] = -1*ink_batch[:, :, 1]
  # Discard paddings.
  strokes = []
  for i in range(ink_batch.shape[0]):
    strokes.append(ink_batch[i][:seq_len[i]])
  return strokes
  
  
def animate_strokes(strokes, color=None):
  # First set up the figure, the axes, and the plot element
  all_strokes = np.concatenate(strokes, axis=0)
  fig, ax = plt.subplots()
  plt.close()
  ax.set_xlim((all_strokes[:, 0].min(), all_strokes[:, 0].max()))
  ax.set_ylim((all_strokes[:, 1].min(), all_strokes[:, 1].max()))
  ax.set_facecolor((1.0, 1.0, 1.0))

  lines = []
  line_borders = []  # Start and end index of a stroke in the entire drawing.
  current_len = 0
  for index, stroke in enumerate(strokes):
      line_borders.append((current_len, current_len+stroke.shape[0]))
      lines.append(ax.plot([],[],lw=2, color=color)[0])
      current_len += stroke.shape[0] + 1

  def init():
      for line in lines:
          line.set_data([],[])
      return lines

  # animation function
  def animate(i):
    for idx, (line, (line_start, line_end)) in enumerate(zip(lines, line_borders)):
      if i > line_end:
        # All strokes that have been visualized.
        line.set_data(strokes[idx][:, 0], strokes[idx][:, 1])
      else:
        line.set_data(strokes[idx][0:(i-line_start), 0], strokes[idx][0:(i-line_start), 1])
        break
    return lines

  anim = animation.FuncAnimation(fig, animate, init_func=init, frames=all_strokes.shape[0], interval=100, blit=True)
  rc('animation', html='jshtml', embed_limit=128)
  return anim

  # from IPython.display import HTML
  # HTML(anim.to_html5_video(embed_limit=128))

In [26]:
# Create Dataset
train_data = TFRecordSingleDiagram(data_path=TF_DATA_PATH, meta_data_path=None, normalize=False, preprocessing=True)
data_iter = train_data.get_iterator()
input_batch, target_batch = next(data_iter)

# Restore Model
model = InkRNNSeq2Seq(latent_units=128, cell_units=512, cell_layers=1, cell_type=C.LSTM, activation=C.RELU)
checkpoint = tf.train.Checkpoint(model=model)
checkpoint_path = tf.train.latest_checkpoint(MODEL_PATH)
if checkpoint_path is None:
  raise Exception("Checkpoint not found in " + MODEL_PATH)
checkpoint.restore(checkpoint_path)

/content/drive/My Drive/ETH/Projects/Google/data/diagrams_with_strokes/training/diagrams_ramer_0.025.tf_record-?????-of-?????
Meta-data not found.
Skipping statistics...


<tensorflow.python.training.checkpointable.util.NameBasedSaverStatus at 0x7f1bb6827828>

In [27]:
# Evaluate the model.
predictions_likelihood = model(inputs=input_batch["encoder_inputs"], decoder_inputs=input_batch["decoder_inputs"])
loss_dict = model.loss(predictions_likelihood, target_batch, input_batch["seq_len"])
model.log_loss(loss_dict, prefix="[Likelihood] ")
ink_reconstruction = tf.concat([predictions_likelihood["stroke"], predictions_likelihood["pen"]], axis=-1)

predictions_autoregressive = model(inputs=input_batch["encoder_inputs"], decoder_inputs=None, output_len=target_batch["stroke"].shape[1])
loss_dict = model.loss(predictions_autoregressive, target_batch, input_batch["seq_len"])
model.log_loss(loss_dict, prefix="[Autoregressive] ")
ink_autoregressive = tf.concat([predictions_autoregressive["stroke"], predictions_autoregressive["pen"]], axis=-1)

[Likelihood] Total: 7.5348 	pen: 7.5342 	stroke: 0.0005 	
[Autoregressive] Total: 7.5341 	pen: 7.5336 	stroke: 0.0005 	


In [23]:
# Animate ground-truth sample.
gt_strokes = tf_ink_batch_to_strokes(train_data.np_undo_preprocessing, input_batch["encoder_inputs"], input_batch["start_point"], input_batch["seq_len"])
gt_anim = animate_strokes(gt_strokes)
gt_anim

Output hidden; open in https://colab.research.google.com to view.

In [28]:
# Animate model reconstruction.
likelihood_strokes = tf_ink_batch_to_strokes(train_data.np_undo_preprocessing, ink_reconstruction, input_batch["start_point"], input_batch["seq_len"])
reconstruction_anim = animate_strokes(likelihood_strokes[:9])
reconstruction_anim