## Adapted from the VAE tutorial at https://github.com/shaohua0116/VAE-Tensorflow

In [2]:
import time
import importlib
import numpy as np
import tensorflow as tf
from tensorflow.contrib.slim import fully_connected as fc
import matplotlib.pyplot as plt 
import glob
from collections import OrderedDict
%matplotlib inline
from datetime import datetime
import os
from pathlib import Path
import layer_def as ld
importlib.reload(ld)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


<module 'layer_def' from '/Users/gongbing/PycharmProjects/GAN_practice/layer_def.py'>

In [4]:
# Split directories of the era5_size_64_64_3_3t_norm TFRecord dataset.
# NOTE: make_dataset() builds its own paths and does not read these variables.
data_root = "/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm"
train_files = data_root + "/train"
test_files = data_root + "/test"
# Previously pointed at the train split by copy-paste mistake.
val_files = data_root + "/val"

In [14]:
# Global configuration read by make_dataset() and the VAE graph.
# num_epochs -> dataset.repeat(); batch_size -> map_and_batch();
# input_dim  -> width of the last fc decoder layer in vae_arc1;
# num_sample -> divided by each trainer's own batch_size argument to get steps per epoch.
num_epochs, batch_size, input_dim, num_sample = 50, 40, 3, 1500
def make_dataset(type="train",
                 data_dir="/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm"):
    """Return an initializable iterator over one split of the TFRecord dataset.

    Each element is an ``OrderedDict`` whose ``"images"`` entry is a float32
    tensor reshaped per example to ``[20, 64, 64, 3]`` and batched by the
    module-level ``batch_size``. The dataset is repeated ``num_epochs`` times
    (also module-level).

    Args:
        type: Split to read: ``"train"``, ``"val"`` or ``"test"``. (Shadows the
            builtin; the name is kept for caller compatibility.)
        data_dir: Directory that holds one sub-directory per split containing
            ``*.tfrecords`` files.

    Returns:
        A ``tf.data`` iterator; run its ``initializer`` before ``get_next()``.

    Raises:
        ValueError: If ``type`` is not one of the known splits.
    """
    split = type
    if split not in ("train", "val", "test"):
        # Previously an unknown split left `filenames` unbound (UnboundLocalError).
        raise ValueError(
            "type must be 'train', 'val' or 'test', got {!r}".format(split))
    # Sort so the file (and therefore record) order is reproducible across runs.
    filenames = sorted(glob.glob(os.path.join(data_dir, split, "*.tfrecords")))

    def parser(serialized_example):
        """Decode one serialized example into {"images": [20, 64, 64, 3] float32}."""
        keys_to_features = {
            'sequence_length': tf.FixedLenFeature([], tf.int64),
            # Frames come back as a 1-D variable-length float list; reshaped below.
            'images/encoded': tf.VarLenFeature(tf.float32),
        }
        parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
        seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
        print("Seq= ", seq.shape)
        seqs = OrderedDict()
        seqs["images"] = tf.reshape(seq, [20, 64, 64, 3], name="reshape_new")
        return seqs

    dataset = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)
    dataset = dataset.repeat(num_epochs)
    # drop_remainder=True discards a final partial batch.
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        parser, batch_size, drop_remainder=True, num_parallel_calls=None))
    return dataset.make_initializable_iterator()

class VariantionalAutoencoder(object):
    """Per-frame convolutional VAE over 20-frame sequences (TensorFlow 1.x graph).

    Each of the 20 frames in ``self.x`` is encoded and decoded independently by
    ``vae_cell``. The reconstruction loss compares the decoded frame ``t``
    against the true frame ``t + 1`` (channel 0 only). The model is therefore
    trained as a next-frame predictor, with a KL term towards N(0, 1).
    """

    def __init__(self, learning_rate=1e-4, batch_size=64, n_z=16):
        # Set hyperparameters
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_z = n_z
        # Build the graph (build() resets the default graph first).
        self.build()
        # Initialize parameters
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())
        # Scalar summaries; merge_all() collects both of them.
        self.recon_summary = tf.summary.scalar("recon_losses", self.recon_loss)
        self.latent_summary = tf.summary.scalar("latent_losses", self.latent_loss)
        # Kept for backward compatibility: this attribute used to end up holding
        # the last summary created.
        self.loss_summary = self.latent_summary
        self.summary_op = tf.summary.merge_all()
        self.summary_dir = "./"
        self.base_dir = "VAE"
        self.checkpoint_dir = "VAE" + "/checkpoint"
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        self.train_log_file = self.base_dir + "/train_" + datetime.now().strftime("%Y%m%d-%H%M%S")
        self.val_log_file = self.base_dir + "/val_" + datetime.now().strftime("%Y%m%d-%H%M%S")
        self.train_writer = tf.summary.FileWriter(self.train_log_file, self.sess.graph)
        self.val_writer = tf.summary.FileWriter(self.val_log_file, self.sess.graph)

    def vae_arc1(self):
        """Fully-connected VAE applied along the last axis of ``self.x``.

        Sets ``self.z_mu``, ``self.z_log_sigma_sq``, ``self.z`` and
        ``self.x_hat``. Returns ``(x_hat, z_log_sigma_sq, z_mu)`` in the order
        ``build`` unpacks. (It previously returned None, so "vanillaVAE2" could
        not be selected.)
        """
        f1 = fc(self.x, 128, scope='enc_fc1', activation_fn=tf.nn.relu)
        f2 = fc(f1, 64, scope='enc_fc2', activation_fn=tf.nn.relu)
        f3 = fc(f2, 32, scope='enc_fc3', activation_fn=tf.nn.relu)
        self.z_mu = fc(f3, self.n_z, scope='enc_fc4_mu', activation_fn=None)
        self.z_log_sigma_sq = fc(f3, self.n_z, scope='enc_fc4_sigma', activation_fn=None)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1).
        eps = tf.random_normal(shape=tf.shape(self.z_log_sigma_sq), mean=0, stddev=1, dtype=tf.float32)
        self.z = self.z_mu + tf.sqrt(tf.exp(self.z_log_sigma_sq)) * eps

        g1 = fc(self.z, 32, scope='dec_fc1', activation_fn=tf.nn.relu)
        g2 = fc(g1, 64, scope='dec_fc2', activation_fn=tf.nn.relu)
        g3 = fc(g2, 128, scope='dec_fc3', activation_fn=tf.nn.relu)
        self.x_hat = fc(g3, input_dim, scope='dec_fc4', activation_fn=tf.sigmoid)
        return self.x_hat, self.z_log_sigma_sq, self.z_mu

    def myModel(self, model_name="vanillaVAE1", *args):
        """Build the architecture named ``model_name`` and return its outputs.

        Only the selected builder is called. Previously the dict literal called
        both builders, which created unused variables and overwrote ``self.z``.

        Raises:
            ValueError: If ``model_name`` is unknown.
        """
        builders = {
            "vanillaVAE1": self.vae_arc_all,
            "vanillaVAE2": self.vae_arc1,
        }
        try:
            builder = builders[model_name]
        except KeyError:
            raise ValueError("Unknown model_name {!r}; expected one of {}".format(
                model_name, sorted(builders))) from None
        self.model = builder(*args)
        return self.model

    @staticmethod
    def vae_cell(x, l_name=0, n_z=16):
        """Encode/decode one frame ``x`` with a small conv VAE.

        Layer names are prefixed with ``"sq_<l_name>_"`` so every frame gets its
        own weights. ``n_z`` is the latent size; the default of 16 matches the
        previously hard-coded value.

        Returns:
            ``(x_hat, z_mu, z_log_sigma_sq, z)`` for this frame.
        """
        seq_name = "sq_" + str(l_name) + "_"
        # NOTE(review): assumes ld.conv_layer(inputs, kernel, stride, features, idx);
        # confirm against layer_def.
        conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1", activate="leaky_relu")
        conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2", activate="leaky_relu")
        conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3", activate="leaky_relu")
        conv4 = tf.layers.Flatten()(conv3)
        conv3_shape = conv3.get_shape().as_list()

        z_mu = ld.fc_layer(conv4, hiddens=n_z, idx=seq_name + "enc_fc4_m")
        # The "enc_fc4_menc_fc4_sigma" suffix came from accidental implicit string
        # concatenation; it is kept verbatim so variable names match existing checkpoints.
        z_log_sigma_sq = ld.fc_layer(conv4, hiddens=n_z, idx=seq_name + "enc_fc4_menc_fc4_sigma")
        eps = tf.random_normal(shape=tf.shape(z_log_sigma_sq), mean=0, stddev=1, dtype=tf.float32)
        z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps

        z2 = ld.fc_layer(z, hiddens=conv3_shape[1] * conv3_shape[2] * conv3_shape[3],
                         idx=seq_name + "decode_fc1")
        z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])

        # Mirror the encoder: stride-2, stride-1, stride-2. conv6 previously read z3,
        # which left conv5 unused and skipped one upsampling step.
        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5", activate="leaky_relu")
        conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, seq_name + "decode_6", activate="leaky_relu")
        # NOTE(review): output activation is leaky_relu; an earlier comment wanted linear.
        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8", activate="leaky_relu")

        return x_hat, z_mu, z_log_sigma_sq, z

    def vae_arc_all(self):
        """Apply ``vae_cell`` to every frame of ``self.x`` and stack along axis 1.

        Also sets ``self.z`` to the stacked per-frame latents.

        Returns:
            ``(x_hat, z_log_sigma_sq, z_mu)``, each stacked along axis 1.
        """
        num_frames = self.x.get_shape().as_list()[1]
        frames, log_sigma_sqs, mus, zs = [], [], [], []
        for i in range(num_frames):
            q, z_mu, z_log_sigma_sq, z = VariantionalAutoencoder.vae_cell(
                self.x[:, i, :, :, :], l_name=i, n_z=self.n_z)
            frames.append(q)
            log_sigma_sqs.append(z_log_sigma_sq)
            mus.append(z_mu)
            zs.append(z)
        x_hat = tf.stack(frames, axis=1)
        z_log_sigma_sq_all = tf.stack(log_sigma_sqs, axis=1)
        z_mu_all = tf.stack(mus, axis=1)
        self.z = tf.stack(zs, axis=1)
        print("X_hat", x_hat.shape)
        print("zlog_sigma_sq_all", z_log_sigma_sq_all.shape)
        return x_hat, z_log_sigma_sq_all, z_mu_all

    def build(self):
        """Reset the default graph and build the input pipeline, model, losses and optimizer."""
        tf.reset_default_graph()
        self.train_iterator = make_dataset(type="train")
        self.val_iterator = make_dataset(type="val")
        self.test_iterator = make_dataset(type="test")
        # Create each get_next op once; calling get_next() per step adds new ops
        # to the graph on every call.
        self.train_next = self.train_iterator.get_next()
        self.val_next = self.val_iterator.get_next()
        self.test_next = self.test_iterator.get_next()
        self.x = tf.placeholder(tf.float32, [None, 20, 64, 64, 3])
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.increment_global_step = tf.assign_add(self.global_step, 1, name='increment_global_step')

        # Architecture
        self.x_hat, self.z_log_sigma_sq, self.z_mu = self.myModel()

        # Reconstruction loss: decoded frame t vs. true frame t+1, channel 0 only (MSE).
        self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0]))

        # Latent loss: KL(q(z|x) || N(0, 1)), summed over the latent dimension
        # (last axis) and averaged over all remaining axes. Axis 1 is the time axis
        # once the per-frame latents are stacked, so summing over it was wrong.
        latent_loss = -0.5 * tf.reduce_sum(
            1 + self.z_log_sigma_sq - tf.square(self.z_mu) - tf.exp(self.z_log_sigma_sq),
            axis=-1)
        self.latent_loss = tf.reduce_mean(latent_loss)

        self.total_loss = self.recon_loss + self.latent_loss
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(self.total_loss, global_step=self.global_step)

        # Build a saver
        self.saver = tf.train.Saver(tf.global_variables())

        self.losses = {
            'recon_loss': self.recon_loss,
            'latent_loss': self.latent_loss,
            'total_loss': self.total_loss,
        }

    def run_single_step(self, global_step):
        """Run one training step and one validation evaluation.

        Summaries are written at ``global_step``. Returns
        ``(train_losses, val_losses)``; either entry is None when its iterator is
        exhausted.
        """
        train_losses = None
        val_losses = None
        try:
            train_batch = self.sess.run(self.train_next)
            print("Train_batch shape", train_batch["images"].shape)
            x_hat, train_summary, _, train_losses = self.sess.run(
                [self.x_hat, self.summary_op, self.train_op, self.losses],
                feed_dict={self.x: train_batch["images"]})
            self.train_writer.add_summary(train_summary, global_step)
            print("x_hat.shape", x_hat.shape)
        except tf.errors.OutOfRangeError:
            print("train out of range error")

        try:
            val_batch = self.sess.run(self.val_next)
            # Evaluate only; running train_op here would fit the model on validation data.
            val_summary, val_losses = self.sess.run(
                [self.summary_op, self.losses],
                feed_dict={self.x: val_batch["images"]})
            self.val_writer.add_summary(val_summary, global_step)
        except tf.errors.OutOfRangeError:
            print("val out of range error")

        return train_losses, val_losses

    # x -> x_hat
    def reconstructor(self, x):
        """Return the model output ``x_hat`` for an input batch ``x``."""
        x_hat = self.sess.run(self.x_hat, feed_dict={self.x: x})
        return x_hat

    # z -> x
    def generator(self, z):
        """Decode latents ``z`` by feeding them in place of ``self.z``.

        NOTE(review): with the default "vanillaVAE1" architecture, ``self.x_hat``
        is computed from per-frame latents, not from the stacked ``self.z``. This
        call then still needs ``self.x`` fed; it only decodes purely from ``z``
        for "vanillaVAE2".
        """
        x_hat = self.sess.run(self.x_hat, feed_dict={self.z: z})
        return x_hat

    # x -> z
    def transformer(self, x):
        """Return the sampled latents ``self.z`` for an input batch ``x``."""
        z = self.sess.run(self.z, feed_dict={self.x: x})
        return z


In [81]:
def trainer(model_class, learning_rate=1e-4,
            batch_size=64, num_epoch=100, n_z=16, log_step=5):
    """Train a freshly built model and checkpoint after every step.

    Args:
        model_class: Model class constructed with ``learning_rate``,
            ``batch_size`` and ``n_z``.
        learning_rate, batch_size, n_z: Forwarded to ``model_class``.
            ``batch_size`` also sets steps per epoch (``num_sample // batch_size``).
        num_epoch: Number of epochs to run.
        log_step: Currently unused; kept for interface compatibility.

    Returns:
        The trained model instance.
    """
    model = model_class(learning_rate=learning_rate, batch_size=batch_size, n_z=n_z)
    steps_per_epoch = num_sample // batch_size
    # Previously referenced a nonexistent `model.checkpoint_path` attribute.
    checkpoint_path = os.path.join(model.checkpoint_dir, 'model.ckpt')

    for epoch in range(num_epoch):
        # The dataset iterators are initializable and must be initialized before use.
        model.sess.run(model.train_iterator.initializer)
        model.sess.run(model.val_iterator.initializer)
        start_time = time.time()
        for step_in_epoch in range(steps_per_epoch):
            step = epoch * steps_per_epoch + step_in_epoch
            # run_single_step requires the step used to tag summaries.
            train_losses, val_losses = model.run_single_step(step)
            print("Train_loss: {}; Val_loss{}".format(train_losses, val_losses))
            model.saver.save(model.sess, checkpoint_path, global_step=step)
        print("Epoch {} took {:.1f}s".format(epoch, time.time() - start_time))

    print('Done!')
    return model

In [None]:
# Load the TensorBoard notebook extension and serve event files found under ./
# (the model writes its train/val logs to VAE/train_* and VAE/val_*).
%load_ext tensorboard.notebook
%tensorboard --logdir=./ --host localhost

In [15]:
def trainer_and_checkpoint(model_class, learning_rate=1e-4,
            batch_size=64, num_epoch=100, n_z=16, log_step=5):
    """Train a model, resuming from the latest checkpoint in its checkpoint dir.

    Args:
        model_class: Model class constructed with ``learning_rate``,
            ``batch_size`` and ``n_z``.
        learning_rate, batch_size, n_z: Forwarded to ``model_class``.
            ``batch_size`` also sets steps per epoch (``num_sample // batch_size``).
        num_epoch: Number of epochs to run.
        log_step: Currently unused; kept for interface compatibility.

    Returns:
        The trained model instance.
    """
    model = model_class(learning_rate=learning_rate, batch_size=batch_size, n_z=n_z)
    ckpt = tf.train.get_checkpoint_state(model.checkpoint_dir)

    if ckpt and ckpt.model_checkpoint_path:
        print("Restore from {}".format(ckpt.model_checkpoint_path))
        # Restore into the session used for training. Restoring into a separate
        # tf.Session() discarded the weights.
        model.saver.restore(model.sess, ckpt.model_checkpoint_path)
        # Checkpoint names end in "-<step>" and that step has already been trained,
        # so resume with the next one.
        global_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) + 1
    else:
        global_step = model.sess.run(model.global_step)

    checkpoint_path = os.path.join(model.checkpoint_dir, 'model_arc4.ckpt')
    for epoch in range(num_epoch):
        model.sess.run(model.train_iterator.initializer)
        model.sess.run(model.val_iterator.initializer)
        start_time = time.time()
        for _ in range(num_sample // batch_size):
            print("global step", global_step)
            train_losses, val_losses = model.run_single_step(global_step)
            print("Train_loss: {}; Val_loss{} for global step {}".format(train_losses, val_losses, global_step))
            model.saver.save(model.sess, checkpoint_path, global_step=global_step)
            global_step += 1
        print("Epoch {} took {:.1f}s".format(epoch, time.time() - start_time))

    print('Done!')
    return model
    

In [None]:
# Train (or resume training) the VAE with the trainer's default hyperparameters.
model_vae2 = trainer_and_checkpoint(model_class=VariantionalAutoencoder)

In [None]:
# Restore the latest checkpoint and run one test batch through the model.
model = VariantionalAutoencoder(learning_rate=1e-4, batch_size=64, n_z=16)
ckpt = tf.train.get_checkpoint_state(model.checkpoint_dir)
if not (ckpt and ckpt.model_checkpoint_path):
    raise FileNotFoundError("No checkpoint found in {}".format(model.checkpoint_dir))
global_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
sess = tf.Session()
saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, tf.train.latest_checkpoint(model.checkpoint_dir))
graph = tf.get_default_graph()
loaded_vars = tf.trainable_variables()
# Reuse the test iterator the model already built, and create its get_next op
# once, outside the loop.
test_iterator = model.test_iterator
test_next = test_iterator.get_next()
sess.run(test_iterator.initializer)
for i in range(1):
    test_batch = sess.run(test_next)
    predict_images, recon_loss = sess.run([model.x_hat, model.recon_loss], feed_dict={model.x: test_batch["images"]})
    # First sequence of the batch: ground truth vs. model output.
    real_img = test_batch["images"][0]
    pred_img = predict_images[0]
    print("recon_loss", recon_loss)
    print("real_image", real_img.shape)
    print("predic image", pred_img.shape)

In [None]:
# Rows of channel 0 of the first ground-truth frame.
list(real_img[0, :, :, 0])

In [None]:
# Rows of channel 0 of the first predicted frame.
list(pred_img[0, :, :, 0])

In [None]:
import matplotlib.gridspec as gridspec

# Plot selected ground-truth frames (channel 0) side by side and save them as a JPEG.
# The body was indented under a column-0 statement (IndentationError) and
# gridspec was never imported.
# NOTE(review): lon, lat, input_images_, args and name are not defined in any
# visible cell of this notebook; define them before running this cell.
# Constants that appear to undo a min-max normalization of channel 0 (presumably
# temperature in K, given the 270-300 colour range; confirm against preprocessing).
chan0_min = 235.2141571044922
chan0_max = 321.46630859375
# Tick positions and tick labels must have equal length; recent matplotlib
# rejects mismatched counts (previously 5 labels were given for 3 ticks).
n_ticks = 3
fig = plt.figure(figsize=(18, 6))
gs = gridspec.GridSpec(1, 10)
gs.update(wspace=0., hspace=0.)
ts = [0, 5, 9, 10, 12, 14, 16, 18, 19]
xlabels = [round(v, 2) for v in np.linspace(np.min(lon), np.max(lon), n_ticks)]
ylabels = [round(v, 2) for v in np.linspace(np.max(lat), np.min(lat), n_ticks)]
tick_positions = list(np.linspace(0, 64, n_ticks))
for t, frame_idx in enumerate(ts):
    ax1 = plt.subplot(gs[t])
    input_image = input_images_[frame_idx, :, :, 0] * (chan0_max - chan0_min) + chan0_min
    plt.imshow(input_image, cmap='jet', vmin=270, vmax=300)
    ax1.title.set_text("t = " + str(frame_idx + 1))
    plt.setp([ax1], xticks=[], xticklabels=[], yticks=[], yticklabels=[])
    if t == 0:
        # Only the leftmost panel gets lon/lat tick labels and a row label.
        plt.setp([ax1], xticks=tick_positions, xticklabels=xlabels,
                 yticks=tick_positions, yticklabels=ylabels)
        plt.ylabel("Ground Truth", fontsize=10)
plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
plt.clf()

In [None]:
# Show channel 0 of the first ground-truth frame. real_img is one [20, 64, 64, 3]
# sequence, so real_img[:, :, 0] would have plotted a 20x64 RGB strip instead.
plt.figure(figsize=(8, 8))
plt.imshow(real_img[0, :, :, 0])

In [6]:
# Shape of the ground-truth sequence taken from the test batch.
np.shape(real_img)

(1, 3, 3)