In [247]:
import time
import importlib
import numpy as np
import tensorflow as tf
from tensorflow.contrib.slim import fully_connected as fc
import matplotlib.pyplot as plt 
import glob
from collections import OrderedDict
%matplotlib inline
from datetime import datetime
import os
from pathlib import Path
import layer_def as ld
import BasicConvLSTMCell
#from BasicConvLSTMCell import BasicConvLSTMCell
from mcnet_ops import *
from mcnet_utils import *
importlib.reload(ld)
importlib.reload(BasicConvLSTMCell)
from BasicConvLSTMCell import BasicConvLSTMCell
# ---------------------------------------------------------------------------
# Data locations
# ---------------------------------------------------------------------------
# Root folder holding the train/ val/ test/ TFRecord sub-directories.
# NOTE(review): this is an absolute, machine-specific path; change DATA_DIR
# when running the notebook elsewhere.
DATA_DIR = "/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm"
train_files = DATA_DIR + "/train"
test_files = DATA_DIR + "/test"
# Previously pointed at the train directory (copy-paste slip); make_dataset()
# reads validation records from ".../val", so keep this consistent with it.
val_files = DATA_DIR + "/val"

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
num_epochs = 50     # number of passes over the data (also used as dataset.repeat count)
input_dim = 3
num_sample = 1500   # samples per epoch; steps per epoch = num_sample // batch_size
alpha = 1           # weight of the image loss (L_img) in the generator objective
beta = 0.02         # weight of the adversarial loss (L_GAN) in the generator objective
batch_size = 32

In [248]:
def make_dataset(type="train",batch_size=batch_size, data_dir=None):
    """Build an initializable iterator over one split of the TFRecord data.

    Args:
        type: Which split to read; one of "train", "val" or "test". (The name
            shadows the builtin but is kept because callers pass it by keyword.)
        batch_size: Number of sequences per batch. Incomplete final batches
            are dropped.
        data_dir: Folder containing the ``train/``, ``val/`` and ``test/``
            sub-directories. Defaults to the original hard-coded location.

    Returns:
        A ``tf.data`` initializable iterator. Each element is an OrderedDict
        with a single key ``"images"`` holding a float32 tensor of shape
        ``[batch_size, 20, 64, 64, 3]``.

    Raises:
        ValueError: If ``type`` is not one of the three known splits.
    """
    import warnings

    valid_splits = ("train", "val", "test")
    if type not in valid_splits:
        # Previously an unknown split fell through every `if` and crashed later
        # with an UnboundLocalError on `filenames`.
        raise ValueError("type must be one of {}, got {!r}".format(valid_splits, type))

    if data_dir is None:
        data_dir = "/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm"

    # glob returns files in arbitrary order; sort so runs are reproducible.
    filenames = sorted(glob.glob(os.path.join(data_dir, type, "*.tfrecords")))
    if not filenames:
        # Only warn: callers may build iterators for splits they never consume.
        warnings.warn("No .tfrecords files found for split {!r} under {}".format(type, data_dir))

    def parser(serialized_example):
        """Decode one serialized example into a dict with an 'images' tensor."""
        seqs = OrderedDict()
        keys_to_features = {
            'sequence_length': tf.FixedLenFeature([], tf.int64),
            'images/encoded': tf.VarLenFeature(tf.float32)
        }
        parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
        seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
        # NOTE(review): the layout is fixed to 20 frames of 64x64x3 here; the
        # parsed 'sequence_length' feature is not used.
        images = tf.reshape(seq, [20,64, 64,3], name = "reshape_new")
        seqs["images"] = images
        return seqs

    dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8 * 1024 * 1024)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
            parser, batch_size, drop_remainder = True, num_parallel_calls = None))

    # The dataset is already batched at this point, so this prefetches
    # `batch_size` whole batches.
    dataset = dataset.prefetch(batch_size)
    iterator = dataset.make_initializable_iterator()
    return iterator

In [332]:
class MCNET(object):
    """Video-prediction network with separate motion and content encoders,
    trained with an image loss plus an adversarial (GAN) loss.

    The model consumes sequences shaped ``[batch, time, height, width, channels]``.
    The first ``context_frames`` (K) frames are the conditioning input and the
    remaining ``sequence_length - context_frames`` (T) frames are predicted
    autoregressively.

    Constructing an instance builds the graph (after resetting the default
    graph), opens a ``tf.InteractiveSession`` and initializes all variables.

    The class relies on notebook-level names: ``make_dataset``, ``alpha``,
    ``beta``, ``num_epochs``, ``num_sample`` and the layer helpers star-imported
    from ``mcnet_ops`` / ``mcnet_utils`` (``conv2d``, ``deconv2d``, ``relu``,
    ``gdl``, ...).
    """

    def __init__(self, learning_rate=1e-2, image_size=[64,64], batch_size=32, c_dim=3,
                 checkpoint_dir=None, is_train=True,context_frames=10,sequence_length=20,model_name="mcnet"):
        """Store the configuration, build the graph and start a session.

        Args:
            learning_rate: Learning rate for both Adam optimizers.
            image_size: ``[height, width]`` of a frame.
            batch_size: Fixed batch size baked into the graph.
            c_dim: Number of image channels.
            checkpoint_dir: Where checkpoints are written/read. Defaults to
                ``<model_name>/checkpoint``.
            is_train: Stored on the instance; the graph is the same either way.
            context_frames: Number of conditioning frames (K); must be >= 2
                because the motion encoder works on frame differences.
            sequence_length: Total frames per sequence (K + T).
            model_name: Base directory for checkpoints and summary logs.

        Raises:
            ValueError: If the frame counts cannot produce at least one
                difference frame and one predicted frame.
        """
        if context_frames < 2:
            raise ValueError("context_frames must be >= 2, got {}".format(context_frames))
        if sequence_length <= context_frames:
            raise ValueError("sequence_length ({}) must exceed context_frames ({})".format(
                sequence_length, context_frames))

        self.batch_size = batch_size
        # Copy so the shared mutable default list can never be modified through us.
        self.image_size = list(image_size)
        self.is_train = is_train
        self.model_name = model_name
        self.lr = learning_rate

        self.gf_dim = 64
        self.df_dim = 64

        self.c_dim = c_dim
        self.context_frames = context_frames # context_frames corresponds to K
        self.K = self.context_frames
        # Was hard-coded to 20, silently ignoring the constructor argument.
        # NOTE(review): make_dataset() still reshapes records to 20 frames.
        self.sequence_length = sequence_length
        self.predict_frames = sequence_length - context_frames # predict_frames corresponds to T
        self.T = self.predict_frames

        self.diff_shape = [self.batch_size, self.context_frames-1, self.image_size[0],
                           self.image_size[1], c_dim]
        self.target_shape = [self.batch_size, self.sequence_length, self.image_size[0], self.image_size[1],c_dim]
        self.xt_shape = [self.batch_size, self.image_size[0], self.image_size[1], c_dim]
        print("target shape inti:",self.target_shape)

        self.build()

        # Initialize parameters
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())

        # Summaries and output locations
        self.summary_op = tf.summary.merge_all()
        self.summary_dir = "./"
        self.base_dir = model_name
        # Honour the caller-supplied directory (previously ignored).
        if checkpoint_dir is not None:
            self.checkpoint_dir = checkpoint_dir
        else:
            self.checkpoint_dir = self.base_dir + "/checkpoint"
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        self.train_log_file = self.base_dir + "/train_" + timestamp
        self.val_log_file = self.base_dir + "/val_" + timestamp
        self.train_writer = tf.summary.FileWriter(self.train_log_file, self.sess.graph)
        self.val_writer = tf.summary.FileWriter(self.val_log_file, self.sess.graph)


    def _frames_to_channels(self, frames):
        """Fold the time axis of ``[B, T, H, W, C]`` into channels -> ``[B, H, W, T*C]``.

        The time axis has to be moved next to the channel axis *before*
        reshaping; reshaping the time-major tensor directly would interleave
        spatial and temporal values.
        """
        frames = tf.transpose(frames, [0, 2, 3, 1, 4])
        return tf.reshape(frames, [self.batch_size, self.image_size[0],
                                   self.image_size[1], -1])


    def build(self):
        """Create the input pipeline, generator, discriminator, losses and optimizers.

        Resets the default graph first, so it must run before the session is
        created.
        """
        tf.reset_default_graph()
        tf.set_random_seed(12345)
        self.train_iterator = make_dataset(type="train", batch_size=self.batch_size)
        self.val_iterator = make_dataset(type="val", batch_size=self.batch_size)
        self.test_iterator = make_dataset(type="test", batch_size=self.batch_size)
        # Build the get_next ops once; calling get_next() inside the training
        # loop would add a new op to the graph at every step.
        self.train_batch_op = self.train_iterator.get_next()
        self.val_batch_op = self.val_iterator.get_next()
        self.global_step = tf.train.get_or_create_global_step()

        # ARCHITECTURE
        self.x = tf.placeholder(tf.float32, self.target_shape, name='target')

        # Last context frame: the content encoder's first input.
        self.xt = self.x[:,self.context_frames-1,:,:,:]

        # Differences between consecutive context frames, [B, K-1, H, W, C].
        self.diff_in = self.x[:,1:self.K,:,:,:] - self.x[:,:self.K-1,:,:,:]

        # Integer division: the feature map is 1/8 of the image after three 2x2 poolings.
        cell = BasicConvLSTMCell([self.image_size[0]//8, self.image_size[1]//8],[3, 3], 256)

        pred = self.forward(self.diff_in, self.xt, cell)

        # Predicted frames stacked along the time axis: [B, T, H, W, C].
        self.G = tf.concat(axis=1,values=pred)

        # Ground-truth future frames. Built unconditionally because the
        # gdl loss below always needs them (it used to exist only when
        # is_train was True, breaking is_train=False).
        true_sim = self.x[:,self.K:,:,:,:]
        # Single-channel data is replicated to three channels for the loss.
        if self.c_dim == 1: true_sim = tf.tile(true_sim,[1,1,1,1,3])
        # Data is already time-major, so frames can be flattened into the batch axis directly.
        true_sim = tf.reshape(true_sim,[-1, self.image_size[0],self.image_size[1], 3])

        gen_sim = self.G
        if self.c_dim == 1: gen_sim = tf.tile(gen_sim,[1,1,1,1,3])
        gen_sim = tf.reshape(gen_sim,[-1, self.image_size[0],self.image_size[1], 3])

        # Discriminator inputs: every frame of a sequence becomes a group of channels.
        binput = self._frames_to_channels(self.x[:,:self.K,:,:,:])
        btarget = self._frames_to_channels(self.x[:,self.K:,:,:,:])
        bgen = self._frames_to_channels(self.G)

        good_data = tf.concat(axis=3,values=[binput,btarget])
        gen_data  = tf.concat(axis=3,values=[binput,bgen])
        self.gen_data = gen_data

        with tf.variable_scope("DIS", reuse=False):
            self.D, self.D_logits = self.discriminator(good_data)

        with tf.variable_scope("DIS", reuse=True):
            self.D_, self.D_logits_ = self.discriminator(gen_data)

        # Image losses: mean squared error plus the gdl() term from mcnet_ops
        # (presumably a gradient-difference loss -- confirm in mcnet_ops).
        self.L_p = tf.reduce_mean(
          tf.square(self.G-self.x[:,self.K:,:,:,:]))

        self.L_gdl = gdl(gen_sim, true_sim, 1.)
        self.L_img = self.L_p + self.L_gdl

        # Adversarial losses.
        self.d_loss_real = tf.reduce_mean(
          tf.nn.sigmoid_cross_entropy_with_logits(
              logits=self.D_logits, labels=tf.ones_like(self.D)
          ))
        self.d_loss_fake = tf.reduce_mean(
          tf.nn.sigmoid_cross_entropy_with_logits(
              logits=self.D_logits_, labels=tf.zeros_like(self.D_)
          ))
        self.d_loss = self.d_loss_real + self.d_loss_fake
        self.L_GAN = tf.reduce_mean(
          tf.nn.sigmoid_cross_entropy_with_logits(
              logits=self.D_logits_, labels=tf.ones_like(self.D_)
          ))

        self.loss_sum = tf.summary.scalar("L_img", self.L_img)
        self.L_p_sum = tf.summary.scalar("L_p", self.L_p)
        self.L_gdl_sum = tf.summary.scalar("L_gdl", self.L_gdl)
        self.L_GAN_sum = tf.summary.scalar("L_GAN", self.L_GAN)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)

        # Generator objective.
        self.total_loss = alpha*self.L_img+beta*self.L_GAN
        self.total_loss_sum = tf.summary.scalar("total_loss", self.total_loss)
        self.g_sum = tf.summary.merge([self.L_p_sum,
                              self.L_gdl_sum, self.loss_sum,
                              self.L_GAN_sum])
        self.d_sum = tf.summary.merge([self.d_loss_real_sum, self.d_loss_sum,
                              self.d_loss_fake_sum])

        # Split trainable variables by owner; discriminator variables live under "DIS".
        self.t_vars = tf.trainable_variables()
        self.g_vars = [var for var in self.t_vars if 'DIS' not in var.name]
        self.d_vars = [var for var in self.t_vars if 'DIS' in var.name]
        num_param = sum(int(np.prod(var.get_shape().as_list())) for var in self.g_vars)
        print("Number of parameters: %d"%num_param)

        # Training ops
        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(
            self.d_loss, var_list=self.d_vars)
        # Passing global_step makes each generator update advance the counter;
        # without it the global step stayed at 0 for the whole run.
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(
            self.total_loss, var_list=self.g_vars, global_step=self.global_step)

        # Created last so that optimizer slot variables are included.
        # save() and load() depend on this attribute.
        self.saver = tf.train.Saver(max_to_keep=10)


    def forward(self, diff_in, xt, cell):
        """Run the generator: encode the context motion, then predict T frames.

        Args:
            diff_in: Context frame differences, ``[B, K-1, H, W, C]``.
            xt: Last context frame, ``[B, H, W, C]``.
            cell: Convolutional LSTM cell called as
                ``cell(input, state, scope=..., reuse=...)``.

        Returns:
            A list of T tensors, each ``[B, 1, H, W, C]``.
        """
        # Initial state. 512 = 2 * 256 cell features -- presumably cell and
        # hidden state concatenated; confirm against BasicConvLSTMCell.
        state = tf.zeros([self.batch_size, self.image_size[0]//8,
                      self.image_size[1]//8, 512])

        # Encoder: feed the K-1 context differences through the motion path.
        reuse = False
        for t in range(self.K-1):
            enc_h, res_m = self.motion_enc(diff_in[:,t,:,:,:], reuse=reuse)
            h_dyn, state = cell(enc_h, state, scope='lstm', reuse=reuse)
            reuse = True

        # Decoder: predict one frame at a time, feeding each prediction back in.
        pred = []
        step_diff = None
        for t in range(self.T):
            if t > 0:
                # Advance the motion state using the difference produced by
                # the previous prediction step.
                enc_h, res_m = self.motion_enc(step_diff, reuse=True)
                h_dyn, state = cell(enc_h, state, scope='lstm', reuse=True)

            # Decoder-side variables are created at t == 0 and reused afterwards.
            dec_reuse = t > 0
            h_cont, res_c = self.content_enc(xt, reuse=dec_reuse)
            h_tp1 = self.comb_layers(h_dyn, h_cont, reuse=dec_reuse)
            res_connect = self.residual(res_m, res_c, reuse=dec_reuse)
            x_hat = self.dec_cnn(h_tp1, res_connect, reuse=dec_reuse)

            step_diff = x_hat - xt
            xt = x_hat
            pred.append(tf.reshape(x_hat,[self.batch_size,1, self.image_size[0],
                                    self.image_size[1], self.c_dim]))
        return pred

    def motion_enc(self, diff_in, reuse):
        """Encode one difference frame with three conv + 2x2 max-pool stages.

        Returns:
            ``(pool3, res_in)``: the final pooled features and the list of the
            three pre-pooling activations used for the residual connections.
        """
        res_in = []

        conv1 = relu(conv2d(diff_in, output_dim=self.gf_dim, k_h=5, k_w=5,
                            d_h=1, d_w=1, name='dyn1_conv1', reuse=reuse))
        res_in.append(conv1)
        pool1 = MaxPooling(conv1, [2,2])

        conv2 = relu(conv2d(pool1, output_dim=self.gf_dim*2, k_h=5, k_w=5,
                            d_h=1, d_w=1, name='dyn_conv2',reuse=reuse))
        res_in.append(conv2)
        pool2 = MaxPooling(conv2, [2,2])

        conv3 = relu(conv2d(pool2, output_dim=self.gf_dim*4, k_h=7, k_w=7,
                            d_h=1, d_w=1, name='dyn_conv3',reuse=reuse))
        res_in.append(conv3)
        pool3 = MaxPooling(conv3, [2,2])
        return pool3, res_in

    def content_enc(self, xt, reuse):
        """Encode one frame with three blocks of 3x3 convs, each followed by 2x2 max-pool.

        Returns:
            ``(pool3, res_in)``: the final pooled features and the last
            activation of each block (before pooling).
        """
        res_in  = []
        conv1_1 = relu(conv2d(xt, output_dim=self.gf_dim, k_h=3, k_w=3,
                              d_h=1, d_w=1, name='cont_conv1_1',reuse=reuse))
        conv1_2 = relu(conv2d(conv1_1, output_dim=self.gf_dim, k_h=3, k_w=3,
                              d_h=1, d_w=1, name='cont_conv1_2',reuse=reuse))
        res_in.append(conv1_2)
        pool1 = MaxPooling(conv1_2, [2,2])

        conv2_1 = relu(conv2d(pool1, output_dim=self.gf_dim*2, k_h=3, k_w=3,
                              d_h=1, d_w=1, name='cont_conv2_1',reuse=reuse))
        conv2_2 = relu(conv2d(conv2_1, output_dim=self.gf_dim*2, k_h=3, k_w=3,
                              d_h=1, d_w=1, name='cont_conv2_2',reuse=reuse))
        res_in.append(conv2_2)
        pool2 = MaxPooling(conv2_2, [2,2])

        conv3_1 = relu(conv2d(pool2, output_dim=self.gf_dim*4, k_h=3, k_w=3,
                              d_h=1, d_w=1, name='cont_conv3_1', reuse=reuse))
        conv3_2 = relu(conv2d(conv3_1, output_dim=self.gf_dim*4, k_h=3, k_w=3,
                              d_h=1, d_w=1, name='cont_conv3_2', reuse=reuse))
        conv3_3 = relu(conv2d(conv3_2, output_dim=self.gf_dim*4, k_h=3, k_w=3,
                              d_h=1, d_w=1, name='cont_conv3_3',reuse=reuse))
        res_in.append(conv3_3)
        pool3 = MaxPooling(conv3_3, [2,2])
        return pool3, res_in

    def comb_layers(self, h_dyn, h_cont, reuse=False):
        """Fuse motion and content features (channel concat followed by three 3x3 convs)."""
        comb1 = relu(conv2d(tf.concat(axis=3,values=[h_dyn, h_cont]),
                            output_dim=self.gf_dim*4, k_h=3, k_w=3,
                            d_h=1, d_w=1, name='comb1',reuse=reuse))
        comb2 = relu(conv2d(comb1, output_dim=self.gf_dim*2, k_h=3, k_w=3,
                            d_h=1, d_w=1, name='comb2', reuse=reuse))
        h_comb = relu(conv2d(comb2, output_dim=self.gf_dim*4, k_h=3, k_w=3,
                             d_h=1, d_w=1, name='h_comb', reuse=reuse))
        return h_comb

    def residual(self, input_dyn, input_cont, reuse=False):
        """Combine per-level motion and content activations into residual features.

        For every level the two activations are concatenated on the channel
        axis and passed through two 3x3 convs; the output has as many channels
        as the content activation of that level.

        Returns:
            A list with one tensor per level, in the same order as the inputs.
        """
        n_layers = len(input_dyn)
        res_out = []
        for level in range(n_layers):
            input_ = tf.concat(axis=3,values=[input_dyn[level],input_cont[level]])
            out_dim = input_cont[level].get_shape()[3]
            res1 = relu(conv2d(input_, output_dim=out_dim,
                             k_h=3, k_w=3, d_h=1, d_w=1,
                             name='res'+str(level)+'_1', reuse=reuse))
            res2 = conv2d(res1, output_dim=out_dim, k_h=3, k_w=3,
                        d_h=1, d_w=1, name='res'+str(level)+'_2', reuse=reuse)
            res_out.append(res2)
        return res_out

    def dec_cnn(self, h_comb, res_connect, reuse=False):
        """Decode fused features back to an image.

        Three stages of un-pooling, adding the matching residual feature
        (coarsest first) and deconvolution; the final layer uses tanh.

        Args:
            h_comb: Fused features at 1/8 resolution.
            res_connect: Residual features as returned by ``residual``,
                ordered from full resolution (index 0) to 1/4 resolution (index 2).

        Returns:
            The predicted frame, ``[B, H, W, c_dim]``.
        """
        # 1/4 resolution
        shapel3 = [self.batch_size, int(self.image_size[0]/4),
                   int(self.image_size[1]/4), self.gf_dim*4]
        shapeout3 = [self.batch_size, int(self.image_size[0]/4),
                     int(self.image_size[1]/4), self.gf_dim*2]
        depool3 = FixedUnPooling(h_comb, [2,2])
        deconv3_3 = relu(deconv2d(relu(tf.add(depool3, res_connect[2])),
                                  output_shape=shapel3, k_h=3, k_w=3,
                                  d_h=1, d_w=1, name='dec_deconv3_3', reuse=reuse))
        deconv3_2 = relu(deconv2d(deconv3_3, output_shape=shapel3, k_h=3, k_w=3,
                                  d_h=1, d_w=1, name='dec_deconv3_2', reuse=reuse))
        deconv3_1 = relu(deconv2d(deconv3_2, output_shape=shapeout3, k_h=3, k_w=3,
                                  d_h=1, d_w=1, name='dec_deconv3_1', reuse=reuse))

        # 1/2 resolution
        shapel2 = [self.batch_size, int(self.image_size[0]/2),
                   int(self.image_size[1]/2), self.gf_dim*2]
        shapeout2 = [self.batch_size, int(self.image_size[0]/2),
                     int(self.image_size[1]/2), self.gf_dim]
        depool2 = FixedUnPooling(deconv3_1, [2,2])
        deconv2_2 = relu(deconv2d(relu(tf.add(depool2, res_connect[1])),
                                  output_shape=shapel2, k_h=3, k_w=3,
                                  d_h=1, d_w=1, name='dec_deconv2_2', reuse=reuse))
        deconv2_1 = relu(deconv2d(deconv2_2, output_shape=shapeout2, k_h=3, k_w=3,
                                  d_h=1, d_w=1, name='dec_deconv2_1', reuse=reuse))

        # Full resolution
        shapel1 = [self.batch_size, self.image_size[0],
                   self.image_size[1], self.gf_dim]
        shapeout1 = [self.batch_size, self.image_size[0],
                     self.image_size[1], self.c_dim]
        depool1 = FixedUnPooling(deconv2_1, [2,2])
        deconv1_2 = relu(deconv2d(relu(tf.add(depool1, res_connect[0])),
                         output_shape=shapel1, k_h=3, k_w=3, d_h=1, d_w=1,
                         name='dec_deconv1_2', reuse=reuse))
        xtp1 = tanh(deconv2d(deconv1_2, output_shape=shapeout1, k_h=3, k_w=3,
                             d_h=1, d_w=1, name='dec_deconv1_1', reuse=reuse))
        return xtp1

    def discriminator(self, image):
        """Score a sequence whose frames are stacked on the channel axis.

        Variable sharing between the real and generated calls is handled by
        the enclosing ``tf.variable_scope("DIS", ...)`` in ``build``.

        Returns:
            ``(probability, logits)``, where probability is ``sigmoid(logits)``.
        """
        h0 = lrelu(conv2d(image, self.df_dim, name='dis_h0_conv'))
        h1 = lrelu(batch_norm(conv2d(h0, self.df_dim*2, name='dis_h1_conv'),
                              "bn1"))
        h2 = lrelu(batch_norm(conv2d(h1, self.df_dim*4, name='dis_h2_conv'),
                              "bn2"))
        h3 = lrelu(batch_norm(conv2d(h2, self.df_dim*8, name='dis_h3_conv'),
                              "bn3"))
        h = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'dis_h3_lin')

        return tf.nn.sigmoid(h), h

    def save(self, sess, checkpoint_dir, step):
        """Write a checkpoint named ``MCNET.model-<step>`` into ``checkpoint_dir``."""
        model_name = "MCNET.model"

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, sess, checkpoint_dir, model_name=None):
        """Restore variables from ``checkpoint_dir``.

        Args:
            sess: Session to restore into.
            checkpoint_dir: Directory containing the checkpoint state file.
            model_name: Specific checkpoint file name; defaults to the latest
                one recorded in the checkpoint state.

        Returns:
            ``(True, model_name)`` on success, ``(False, None)`` when no
            checkpoint is found.
        """
        print(" [*] Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            if model_name is None: model_name = ckpt_name
            self.saver.restore(sess, os.path.join(checkpoint_dir, model_name))
            print(" Loaded model: "+str(model_name))
            return True, model_name
        else:
            return False, None


    def run_single_step(self,global_step):
        """Run one generator update, one discriminator update and one validation pass.

        Args:
            global_step: Step value used to tag the written summaries.

        Returns:
            ``(train_total_loss, val_total_loss)``. Either entry is ``None``
            when the corresponding dataset is exhausted.
        """
        print ("global_step:",global_step)
        # Defaults so the return statement is valid when an iterator runs dry.
        train_total_loss = None
        val_total_loss = None

        try:
            train_batch = self.sess.run(self.train_batch_op)
            feed = {self.x: train_batch["images"]}
            _, g_sum = self.sess.run([self.g_optim, self.g_sum], feed_dict=feed)
            _, d_sum = self.sess.run([self.d_optim, self.d_sum], feed_dict=feed)
            # Loss on the same batch after both updates.
            train_total_loss = self.sess.run(self.total_loss, feed_dict=feed)
            self.train_writer.add_summary(g_sum, global_step)
            self.train_writer.add_summary(d_sum, global_step)
        except tf.errors.OutOfRangeError:
            print("train out of range error")

        try:
            val_batch = self.sess.run(self.val_batch_op)
            val_total_loss = self.sess.run(self.total_loss, feed_dict={self.x: val_batch["images"]})
        except tf.errors.OutOfRangeError:
            print("val out of range error")

        return train_total_loss, val_total_loss


    def trainer(self):
        """Train for ``num_epochs`` epochs, resuming from the latest checkpoint if any.

        A checkpoint is written to ``self.checkpoint_dir`` at the end of every
        epoch.

        Returns:
            The model instance itself.
        """
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restore from {}".format(ckpt.model_checkpoint_path))
            # Restore into the session that is actually used for training.
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            print("Initializer from scratch")

        # The iterators must be initialized on both the fresh and the resumed path.
        self.sess.run([self.train_iterator.initializer, self.val_iterator.initializer])

        steps_per_epoch = num_sample // self.batch_size
        for epoch in range(num_epochs):
            start_time = time.time()
            for step in range(steps_per_epoch):
                print("iter",step)
                global_step = self.sess.run(self.global_step)
                train_total_loss, val_total_loss = self.run_single_step(global_step)
                print ("Train_loss: {}; Val_loss{} for global step {}".format(train_total_loss,val_total_loss,global_step))
            print("Epoch {} finished in {:.1f}s".format(epoch, time.time() - start_time))
            self.save(self.sess, self.checkpoint_dir, int(self.sess.run(self.global_step)))
        print('Done!')
        return self


In [333]:
# Build the model (graph + session) and run the training loop.
# `model_class` is not defined in this notebook; the class defined above is MCNET.
mcnet_model = MCNET()
# trainer() returns the trained model instance.
model_vae2 = mcnet_model.trainer()