In [49]:
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm.notebook import trange

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

%load_ext autoreload
%load_ext tensorboard

%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [50]:
#SOURCE_PATH = '/media/jeromeku/easystore/workspace/deep_learning/tensorflow/notebooks/src/'
if not SOURCE_PATH in sys.path:
    sys.path.append(SOURCE_PATH)

In [51]:
from disentangled.flags import FLAGS
FLAGS.mark_as_parsed()

import disentangled.models as models

In [52]:
#Data Loading

sprites = sprites_datasetv2.SpritesDataset(fake_data=FLAGS.fake_data)

#Sample data
train = iter(sprites.train)
def unmap(example):
    seq, skin_idx,hair_idx,top_idx,pants_idx,act_idx,skin_name,hair_name, top_name,pants_name,act_name = example
    ex_map = dict(seq=seq, 
                  skin_idx=skin_idx, 
                  hair_idx=hair_idx, 
                  top_idx=top_idx, 
                  pants_idx=pants_idx, 
                  act_idx=act_idx, 
                  skin_name=skin_name,
                  hair_name=hair_name, 
                  top_name=top_name, 
                  pants_name=pants_name,
                  act_name=act_name)
    return ex_map

ex = unmap(next(iter(sprites.train)))
ex['pants_name']



<tf.Tensor: id=2397, shape=(), dtype=string, numpy=b'legs/pants/male/white_pants_male.png'>

In [53]:
#Build model
model = models.DisentangledSequentialVAE(
      latent_size_static=FLAGS.latent_size_static,
      latent_size_dynamic=FLAGS.latent_size_dynamic,
      hidden_size=FLAGS.hidden_size, channels=sprites.channels,
      latent_posterior=FLAGS.latent_posterior)

In [54]:
#Optimizer and learning rate schedule
step = tf.Variable(0, dtype=tf.int64, trainable=False) 
schedule = tf.compat.v1.train.cosine_decay(FLAGS.learning_rate, step, FLAGS.max_steps)
optimizer = tf.keras.optimizers.Adam(schedule)

In [55]:
#Tensorboard
timestamp = datetime.strftime(datetime.today(), "%y%m%d_%H%M%S")
logdir = 'logs/train_data'
file_writer = tf.summary.create_file_writer(logdir + timestamp)
file_writer.set_as_default()

In [56]:
#Checkpointing
checkpoint = tf.train.Checkpoint(step=step, optimizer=optimizer, net=model)
ckpt_manager = tf.train.CheckpointManager(checkpoint, './tf_ckpts', max_to_keep=3)

In [78]:
dataset = sprites.train.map(lambda *x: x[0]).shuffle(1000).repeat()


In [79]:
len(list(dataset))

1000

In [84]:
tensors = tf.stack([t for t in dataset], axis=0)

In [86]:
arrs = tensors.numpy()

In [89]:
np.savez('./data/train', *arrs)

In [90]:
arr_objs = np.load('./data/train.npz')

In [95]:
meta_data = {"length": sprites.length, "channels": sprites.channels, "frame_size": sprites.frame_size}

In [83]:
tf.stack(tensors, axis=0).shape

TensorShape([5, 8, 64, 64, 3])

In [81]:
arrs = [t.numpy() for t in tensors]

In [82]:
np.concatenate(arrs, axis=0).shape

(40, 64, 64, 3)

In [77]:
t.shape

TensorShape([8, 64, 64, 3])

In [58]:
dataset = dataset.batch(FLAGS.batch_size).take(FLAGS.max_steps)

In [59]:
inputs = next(iter(dataset))

In [97]:
%%writefile ./src/disentangled/summary.py

def image_summary(seqs, name, step, num=None):
  """Visualizes sequences as TensorBoard summaries.

  Args:
    seqs: A tensor of shape [n, t, h, w, c].
    name: String name of this summary.
    num: Integer for the number of examples to visualize. Defaults to
      all examples.
  """
  seqs = tf.clip_by_value(seqs, 0., 1.)
  seqs = tf.unstack(seqs[:num])
  joined_seqs = [tf.concat(tf.unstack(seq), 1) for seq in seqs]
  joined_seqs = tf.expand_dims(tf.concat(joined_seqs, 0), 0)
  tf.compat.v2.summary.image(
      name,
      joined_seqs,
      max_outputs=1,
      step=step)


def visualize_reconstruction(inputs, reconstruct, num=3, name="reconstruction"):
  """Visualizes the reconstruction of inputs in TensorBoard.

  Args:
    inputs: A tensor of the original inputs, of shape [batch, timesteps,
      h, w, c].
    reconstruct: A tensor of a reconstruction of inputs, of shape
      [batch, timesteps, h, w, c].
    num: Integer for the number of examples to visualize.
    name: String name of this summary.
  """
  reconstruct = tf.clip_by_value(reconstruct, 0., 1.)
  inputs_and_reconstruct = tf.concat((inputs[:num], reconstruct[:num]), axis=0)
  image_summary(inputs_and_reconstruct, name)


def visualize_qualitative_analysis(inputs, model, samples=1, batch_size=3,
                                   length=8):
  """Visualizes a qualitative analysis of a given model.

  Args:
    inputs: A tensor of the original inputs, of shape [batch, timesteps,
      h, w, c].
    model: A DisentangledSequentialVAE model.
    samples: Number of samples to draw from the latent distributions.
    batch_size: Number of sequences to generate.
    length: Number of timesteps to generate for each sequence.
  """
  average = lambda dist: tf.reduce_mean(
      input_tensor=dist.mean(), axis=0)  # avg over samples
  with tf.compat.v1.name_scope("val_reconstruction"):
    reconstruct = functools.partial(model.reconstruct, inputs=inputs,
                                    samples=samples)
    visualize_reconstruction(inputs, average(reconstruct()))
    visualize_reconstruction(inputs, average(reconstruct(sample_static=True)),
                             name="static_prior")
    visualize_reconstruction(inputs, average(reconstruct(sample_dynamic=True)),
                             name="dynamic_prior")
    visualize_reconstruction(inputs, average(reconstruct(swap_static=True)),
                             name="swap_static")
    visualize_reconstruction(inputs, average(reconstruct(swap_dynamic=True)),
                             name="swap_dynamic")

  with tf.compat.v1.name_scope("generation"):
    generate = functools.partial(model.generate, batch_size=batch_size,
                                 length=length, samples=samples)
    image_summary(average(generate(fix_static=True)), "fix_static")
    image_summary(average(generate(fix_dynamic=True)), "fix_dynamic")


def summarize_dist_params(dist, name, step, name_scope="dist_params"):
  """Summarize the parameters of a distribution.

  Args:
    dist: A Distribution object with mean and standard deviation
      parameters.
    name: The name of the distribution.
    name_scope: The name scope of this summary.
  """
  with tf.compat.v1.name_scope(name_scope):
    tf.compat.v2.summary.histogram(
        name="{}/{}".format(name, "mean"),
        data=dist.mean(),
        step=step)
    tf.compat.v2.summary.histogram(
        name="{}/{}".format(name, "stddev"),
        data=dist.stddev(),
        step=step)


def summarize_mean_in_nats_and_bits(inputs, units, name, step,
                                    nats_name_scope="nats",
                                    bits_name_scope="bits_per_dim"):
  """Summarize the mean of a tensor in nats and bits per unit.

  Args:
    inputs: A tensor of values measured in nats.
    units: The units of the tensor with which to compute the mean bits
      per unit.
    name: The name of the tensor.
    nats_name_scope: The name scope of the nats summary.
    bits_name_scope: The name scope of the bits summary.
  """
  mean = tf.reduce_mean(input_tensor=inputs)
  with tf.compat.v1.name_scope(nats_name_scope):
    tf.compat.v2.summary.scalar(
        name,
        mean,
        step=step)
  with tf.compat.v1.name_scope(bits_name_scope):
    tf.compat.v2.summary.scalar(
        name,
        mean / units / tf.math.log(2.),
        step=step)


Overwriting ./src/disentangled/summary.py


In [120]:
%%writefile ./src/disentangled/train.py

def train_step(model, optimizer, dataset, flags, summary_writer):
    for inputs in dataset.prefetch(buffer_size=None):
      with tf.compat.v2.summary.record_if(
          lambda: tf.math.equal(0, global_step % FLAGS.log_steps)):
        tf.compat.v2.summary.histogram(
            "image",
            data=inputs,
            step=tf.compat.v1.train.get_or_create_global_step())

      with tf.GradientTape() as tape:
        features = model.compressor(inputs)  # (batch, timesteps, hidden)
        static_sample, static_posterior = model.sample_static_posterior(
            features, FLAGS.num_samples)  # (samples, batch, latent)
        dynamic_sample, dynamic_posterior = model.sample_dynamic_posterior(
            features, FLAGS.num_samples, static_sample)  # (sampl, N, T, latent)
        likelihood = model.decoder((dynamic_sample, static_sample))

        reconstruction = tf.reduce_mean(  # integrate samples
            input_tensor=likelihood.mean()[:FLAGS.num_reconstruction_samples],
            axis=0)
        visualize_reconstruction(inputs, reconstruction,
                                 name="train_reconstruction")

        static_prior = model.static_prior()
        _, dynamic_prior = model.sample_dynamic_prior(
            FLAGS.num_samples, FLAGS.batch_size, sprites_data.length)

        if FLAGS.enable_debug_logging:
          summarize_dist_params(static_prior, "static_prior")
          summarize_dist_params(static_posterior, "static_posterior")
          summarize_dist_params(dynamic_prior, "dynamic_prior")
          summarize_dist_params(dynamic_posterior, "dynamic_posterior")
          summarize_dist_params(likelihood, "likelihood")

        static_prior_log_prob = static_prior.log_prob(static_sample)
        static_posterior_log_prob = static_posterior.log_prob(static_sample)
        dynamic_prior_log_prob = tf.reduce_sum(
            input_tensor=dynamic_prior.log_prob(dynamic_sample),
            axis=-1)  # sum time
        dynamic_posterior_log_prob = tf.reduce_sum(
            input_tensor=dynamic_posterior.log_prob(dynamic_sample),
            axis=-1)  # sum time
        likelihood_log_prob = tf.reduce_sum(
            input_tensor=likelihood.log_prob(inputs), axis=-1)  # sum time

        if FLAGS.enable_debug_logging:
          with tf.compat.v1.name_scope("log_probs"):
            summarize_mean_in_nats_and_bits(
                static_prior_log_prob, FLAGS.latent_size_static, "static_prior")
            summarize_mean_in_nats_and_bits(
                static_posterior_log_prob, FLAGS.latent_size_static,
                "static_posterior")
            summarize_mean_in_nats_and_bits(
                dynamic_prior_log_prob, FLAGS.latent_size_dynamic *
                sprites_data.length, "dynamic_prior")
            summarize_mean_in_nats_and_bits(
                dynamic_posterior_log_prob, FLAGS.latent_size_dynamic *
                sprites_data.length, "dynamic_posterior")
            summarize_mean_in_nats_and_bits(
                likelihood_log_prob, sprites_data.frame_size ** 2 *
                sprites_data.channels * sprites_data.length, "likelihood")

        elbo = tf.reduce_mean(input_tensor=static_prior_log_prob -
                              static_posterior_log_prob +
                              dynamic_prior_log_prob -
                              dynamic_posterior_log_prob + likelihood_log_prob)
        loss = -elbo
        tf.compat.v2.summary.scalar(
            "elbo",
            elbo,
            step=tf.compat.v1.train.get_or_create_global_step())

      grads = tape.gradient(loss, model.variables)
      grads, global_norm = tf.clip_by_global_norm(grads, FLAGS.clip_norm)
      grads_and_vars = list(zip(grads, model.variables))  # allow reuse in py3
      if FLAGS.enable_debug_logging:
        with tf.compat.v1.name_scope("grads"):
          tf.compat.v2.summary.scalar(
              "global_norm_grads",
              global_norm,
              step=tf.compat.v1.train.get_or_create_global_step())
          tf.compat.v2.summary.scalar(
              "global_norm_grads_clipped",
              tf.linalg.global_norm(grads),
              step=tf.compat.v1.train.get_or_create_global_step())
        for grad, var in grads_and_vars:
          with tf.compat.v1.name_scope("grads"):
            tf.compat.v2.summary.histogram(
                "{}/grad".format(var.name),
                data=grad,
                step=tf.compat.v1.train.get_or_create_global_step())
          with tf.compat.v1.name_scope("vars"):
            tf.compat.v2.summary.histogram(
                var.name,
                data=var,
                step=tf.compat.v1.train.get_or_create_global_step())
      optimizer.apply_gradients(grads_and_vars, global_step)

    is_log_step = global_step.numpy() % FLAGS.log_steps == 0
    is_final_step = global_step.numpy() == FLAGS.max_steps
    if is_log_step or is_final_step:
      checkpoint_manager.save()
      print("ELBO ({}/{}): {}".format(global_step.numpy(), FLAGS.max_steps,
                                      elbo.numpy()))
      with tf.compat.v2.summary.record_if(True):
        val_data = sprites_data.test.take(20)
        inputs = next(iter(val_data.shuffle(20).batch(3)))[0]
        visualize_qualitative_analysis(inputs, model,
                                       FLAGS.num_reconstruction_samples)

    writer.flush()


Writing ./src/disentangled/train.py
