## Tutorial from https://github.com/shaohua0116/VAE-Tensorflow

In [1]:
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib.slim import fully_connected as fc
import matplotlib.pyplot as plt 
import glob
from collections import OrderedDict
%matplotlib inline
from datetime import datetime
import os

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
train_files = "/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm/train"
test_files = "/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm/test"
val_files = "/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm/train"

In [3]:
num_epochs = 50
batch_size=40
input_dim=3

num_sample = 1500
def make_dataset(type="train"):
    if type=="train": filenames = glob.glob("/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm/train/*.tfrecords")
    if type=="val":filenames = glob.glob("/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm/val/*.tfrecords")
    if type=="test":filenames = glob.glob("/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm/test/*.tfrecords")
    
    def parser(serialized_example):
            seqs = OrderedDict()
            keys_to_features = {
                # 'width': tf.FixedLenFeature([], tf.int64),
                # 'height': tf.FixedLenFeature([], tf.int64),
                'sequence_length': tf.FixedLenFeature([], tf.int64),
                # 'channels': tf.FixedLenFeature([],tf.int64),
                # 'images/encoded':  tf.FixedLenFeature([], tf.string)
                'images/encoded': tf.VarLenFeature(tf.float32)
            }

            parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
            seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
            print("Seq= ",seq.shape)
            images = tf.reshape(seq, [20,64, 64,3], name = "reshape_new")
            seqs["images"] = images
            return seqs
    dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8 * 1024 * 1024)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
            parser, batch_size, drop_remainder = True, num_parallel_calls = None))
    #dataset = dataset.map(parser)
    # num_parallel_calls = None if shuffle else 1  # for reproducibility (e.g. sampled subclips from the test set)
    # dataset = dataset.apply(tf.contrib.data.map_and_batch(
    #    _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) #  Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs
    dataset = dataset.prefetch(batch_size)  # Bing: Take the data to buffer inorder to save the waiting time for GPU
    print("dataset",dataset)
    #dataset = dataset.repeat(max_step)
    #dataset = dataset.batch(batch_size)
    #iterator = dataset.make_one_shot_iterator() #One shot iterator will pool all the data once and memery issue
    iterator = dataset.make_initializable_iterator()
    return iterator

class VariantionalAutoencoder(object):

    def __init__(self, learning_rate=1e-4, batch_size=64, n_z=16):
        # Set hyperparameters
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_z = n_z

        # Build the graph
        self.build()
        # Initialize paramters
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())
        # Summary op
        self.loss_summary = tf.summary.scalar("losses", self.recon_loss)
        self.summary_op = tf.summary.merge_all()
        self.summary_dir = "./"
        self.train_log_file = self.summary_dir + "/train_"  + datetime.now().strftime("%Y%m%d-%H%M%S")
        self.val_log_file = self.summary_dir + "/val_"  + datetime.now().strftime("%Y%m%d-%H%M%S")
        self.train_writer = tf.summary.FileWriter(self.train_log_file, self.sess.graph)
        self.val_writer = tf.summary.FileWriter(self.val_log_file, self.sess.graph)
        self.sess.run(self.train_iterator.initializer)
        self.sess.run(self.val_iterator.initializer)
        
    def vae_arc1(self):
        
        # Encode
        # x -> z_mean, z_sigma -> z
        f1 = fc(self.x, 128, scope='enc_fc1', activation_fn=tf.nn.relu)
        f2 = fc(f1, 64, scope='enc_fc2', activation_fn=tf.nn.relu)
        f3 = fc(f2, 32, scope='enc_fc3', activation_fn=tf.nn.relu)
        self.z_mu = fc(f3, self.n_z, scope='enc_fc4_mu', 
                       activation_fn=None)
        self.z_log_sigma_sq = fc(f3, self.n_z, scope='enc_fc4_sigma', 
                                 activation_fn=None)
        eps = tf.random_normal(shape=tf.shape(self.z_log_sigma_sq),mean=0, stddev=1, dtype=tf.float32)
        
        self.z = self.z_mu + tf.sqrt(tf.exp(self.z_log_sigma_sq)) * eps

        # Decode
        # z -> x_hat
        g1 = fc(self.z, 32, scope='dec_fc1', activation_fn=tf.nn.relu)
        g2 = fc(g1, 64, scope='dec_fc2', activation_fn=tf.nn.relu)
        g3 = fc(g2, 128, scope='dec_fc3', activation_fn=tf.nn.relu)
        self.x_hat = fc(g3, input_dim, scope='dec_fc4', activation_fn=tf.sigmoid)
        return 

    def vae_arc2(self):
        
        # Encode
        # x -> z_mean, z_sigma -> z
        f3 = fc(self.x, 32, scope='enc_fc1', activation_fn=tf.nn.relu)

        self.z_mu = fc(f3, self.n_z, scope='enc_fc4_mu', 
                       activation_fn=None)
        self.z_log_sigma_sq = fc(f3, self.n_z, scope='enc_fc4_sigma', 
                                 activation_fn=None)
        eps = tf.random_normal(shape=tf.shape(self.z_log_sigma_sq),mean=0, stddev=1, dtype=tf.float32)
        self.z = self.z_mu + tf.sqrt(tf.exp(self.z_log_sigma_sq)) * eps
        
        # Decode
        # z -> x_hat
        g1 = fc(self.z, 32, scope='dec_fc1', activation_fn=tf.nn.relu)
        self.x_hat = fc(g1, input_dim, scope='dec_fc4', activation_fn=tf.sigmoid)
        return  
    
        
    # Build the netowrk and the loss functions
    def build(self):
        
        tf.reset_default_graph()
        self.train_iterator = make_dataset(type="train")
        self.val_iterator = make_dataset(type="val")
        self.test_iterator = make_dataset(type="test")
        self.x = tf.placeholder(tf.float32, [None,20,64,64,3])
        
        #ARCHITECTURE
        self.vae_arc2()

        # Loss
        # Reconstruction loss
        # Minimize the cross-entropy loss
        epsilon = 1e-10
        recon_loss = -tf.reduce_sum(
            self.x[:,1:,:,:,:] * tf.log(epsilon+self.x_hat[:,:-1,:,:,:]) + 
            (1-self.x[:,1:,:,:,:]) * tf.log(epsilon+1-self.x_hat[:,:-1,:,:,:]), 
            axis=1
        )
        self.recon_loss = tf.reduce_mean(recon_loss)

        # Latent loss
        # KL divergence: measure the difference between two distributions
        # Here we measure the divergence between 
        # the latent distribution and N(0, 1)
        latent_loss = -0.5 * tf.reduce_sum(
            1 + self.z_log_sigma_sq - tf.square(self.z_mu) - 
            tf.exp(self.z_log_sigma_sq), axis=1)
        self.latent_loss = tf.reduce_mean(latent_loss)

        self.total_loss = self.recon_loss + self.latent_loss
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(self.total_loss)
        
        # Build a saver
        self.saver = tf.train.Saver(tf.global_variables())
        
        self.losses = {
            'recon_loss': self.recon_loss,
            'latent_loss': self.latent_loss,
            'total_loss': self.total_loss,
        }      # H(x, x_hat) = -\Sigma x*log(x_hat) + (1-x)*log(1-x_hat)

        
        return

    # Execute the forward and the backward pass
    def run_single_step(self,step):
        try:
            train_batch = self.sess.run(self.train_iterator.get_next())
            print("Train_batch shape",train_batch["images"].shape)
            x_hat, train_summary, _, train_losses = self.sess.run([self.x_hat,self.summary_op,self.train_op, self.recon_loss], feed_dict={self.x: train_batch["images"]})
            self.train_writer.add_summary(train_summary, step)
            print("x_hat.shape",x_hat.shape)
        except tf.errors.OutOfRangeError:
            print("train out of range error")
        
        try:
            val_batch = self.sess.run(self.val_iterator.get_next())
            val_summary, _, val_losses = self.sess.run([self.summary_op,self.train_op, self.recon_loss], feed_dict={self.x: val_batch["images"]})
            self.val_writer.add_summary(val_summary, step)
        except tf.errors.OutOfRangeError:
            print("train out of range error")
        
        return train_losses,val_losses

    # x -> x_hat
    def reconstructor(self, x):
        x_hat = self.sess.run(self.x_hat, feed_dict={self.x: x})
        return x_hat

    # z -> x
    def generator(self, z):
        x_hat = self.sess.run(self.x_hat, feed_dict={self.z: z})
        return x_hat
    
    
    # x -> z
    def transformer(self, x):
        z = self.sess.run(self.z, feed_dict={self.x: x})
        return z

In [111]:
def trainer(model_class, learning_rate=1e-4, 
            batch_size=64, num_epoch=100, n_z=16, log_step=5):
    # Create a model    
    model = model_class(learning_rate=learning_rate, batch_size=batch_size, n_z=n_z)

    # Training loop    
    for epoch in range(num_epoch):
        start_time = time.time()
        
        # Run an epoch
        for iter in range(num_sample // batch_size):
            # Get a batch
            step = epoch*(num_sample // batch_size) +  iter
            train_losses,val_losses = model.run_single_step(step=step)
            print ("Train_loss: {}; Val_loss{}".format(train_losses,val_losses))
            checkpoint_path = os.path.join(model.summary_dir, 'model.ckpt')
            model.saver.save(model.sess, checkpoint_path, global_step =step)
        end_time = time.time()
        
        # Log the loss
#         if epoch % log_step == 0:
#             log_str = '[Epoch {}] '.format(epoch)
#             for k, v in self.recon_loss.items():
#                 log_str += '{}: {:.3f}  '.format(k, v)
#             log_str += '({:.3f} sec/epoch)'.format(end_time - start_time)
#             print(log_str)

    
    print('Done!')
    return model

In [None]:
model_vae = trainer(VariantionalAutoencoder)

In [5]:
%load_ext tensorboard.notebook
%tensorboard --logdir=./ --host localhost

Reusing TensorBoard on port 6006 (pid 2527), started 3:45:29 ago. (Use '!kill 2527' to kill it.)

In [None]:

model = VariantionalAutoencoder(learning_rate=1e-4, batch_size=64, n_z=16)
#First let's load meta graph and restore weights
sess=tf.Session()  
saver = tf.train.import_meta_graph('model.ckpt-115.meta')
saver.restore(sess,'model.ckpt-115')
#latest_checkpoints = saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
#op = sess.graph.get_operations()
#graph.get_tensor_by_name("dec_fc4")
loaded_vars = tf.trainable_variables() 
#loaded_vars
op_to_restore = graph.get_tensor_by_name("dec_fc4/biases:0")
test_iterator = make_dataset(type="test")
sess.run(test_iterator.initializer)
test_batch = sess.run(test_iterator.get_next())
# print("test_batch",test_batch["images"].shape)
op_to_restore.eval(feed_dict={model.x: test_batch["images"]})
sess.run([model.x_hat,op_to_restore], feed_dict={model.x: test_batch["images"]})
#https://jhui.github.io/2017/03/08/TensorFlow-variable-sharing/

In [None]:
def train_and_checkpoint(model,manager)

In [64]:
loaded_vars

[<tf.Variable 'enc_fc1/weights:0' shape=(3, 32) dtype=float32_ref>,
 <tf.Variable 'enc_fc1/biases:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'enc_fc4_mu/weights:0' shape=(32, 16) dtype=float32_ref>,
 <tf.Variable 'enc_fc4_mu/biases:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'enc_fc4_sigma/weights:0' shape=(32, 16) dtype=float32_ref>,
 <tf.Variable 'enc_fc4_sigma/biases:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'dec_fc1/weights:0' shape=(16, 32) dtype=float32_ref>,
 <tf.Variable 'dec_fc1/biases:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'dec_fc4/weights:0' shape=(32, 3) dtype=float32_ref>,
 <tf.Variable 'dec_fc4/biases:0' shape=(3,) dtype=float32_ref>,
 <tf.Variable 'enc_fc1/weights:0' shape=(3, 32) dtype=float32_ref>,
 <tf.Variable 'enc_fc1/biases:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'enc_fc4_mu/weights:0' shape=(32, 16) dtype=float32_ref>,
 <tf.Variable 'enc_fc4_mu/biases:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'enc_fc4_sigma/weights:0' sh