In [52]:
### Ref: https://github.com/loliverhennigh/Convolutional-LSTM-in-Tensorflow

In [2]:
import time
import importlib
import numpy as np
import tensorflow as tf
from tensorflow.contrib.slim import fully_connected as fc
import matplotlib.pyplot as plt 
import glob
from collections import OrderedDict
%matplotlib inline
from datetime import datetime
import os
from pathlib import Path
import layer_def as ld
from BasicConvLSTMCell import BasicConvLSTMCell
importlib.reload(ld)

# Hyperparameters (module-level; read as globals by the cells below)
num_epochs = 50    # number of passes handed to dataset.repeat() in make_dataset
batch_size = 40    # batch size (and prefetch depth) of the input pipeline in make_dataset
input_dim = 3      # not referenced by any visible cell
num_sample = 1500  # trainer_and_checkpoint runs num_sample // batch_size iterations per epoch

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [14]:
def make_dataset(type="train", data_dir=None):
    """Build an initializable iterator over one split of the TFRecord data.

    Reads the module-level ``num_epochs`` and ``batch_size`` hyperparameters.

    Args:
        type: Split to load: "train", "val" or "test". (The name shadows the
            builtin but is kept because callers pass it by keyword.)
        data_dir: Directory holding one sub-folder of ``*.tfrecords`` files per
            split. Defaults to the location that used to be hard-coded here.

    Returns:
        An initializable ``tf.data`` iterator. Each element is a dict whose
        "images" entry is a float32 tensor of shape
        ``[batch_size, 20, 64, 64, 3]``.

    Raises:
        ValueError: If ``type`` is not one of the known splits.
    """
    valid_splits = ("train", "val", "test")
    # Validate up front: previously an unknown split left `filenames` unbound
    # and surfaced as an UnboundLocalError further down.
    if type not in valid_splits:
        raise ValueError(
            "Unknown dataset split {!r}; expected one of {}".format(type, valid_splits))
    if data_dir is None:
        data_dir = "/Users/gongbing/PycharmProjects/video_prediction_savp/data/era5_size_64_64_3_3t_norm"
    filenames = glob.glob(os.path.join(data_dir, type, "*.tfrecords"))

    def parser(serialized_example):
        """Decode one serialized example into {"images": [20, 64, 64, 3] tensor}."""
        seqs = OrderedDict()
        keys_to_features = {
            'sequence_length': tf.FixedLenFeature([], tf.int64),
            'images/encoded': tf.VarLenFeature(tf.float32)
        }
        parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
        # The frames are stored as a flat sparse float list; densify, then
        # restore the (time, height, width, channel) layout.
        seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
        images = tf.reshape(seq, [20, 64, 64, 3], name="reshape_new")
        seqs["images"] = images
        return seqs

    dataset = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)
    dataset = dataset.repeat(num_epochs)
    # drop_remainder=True keeps the batch dimension fixed at batch_size.
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        parser, batch_size, drop_remainder=True, num_parallel_calls=None))

    dataset = dataset.prefetch(batch_size)
    iterator = dataset.make_initializable_iterator()
    return iterator


class convLSTM(object):
    """Convolutional-LSTM frame predictor built as a TensorFlow 1.x static graph.

    Constructing an instance resets the default graph, builds the input
    iterators, network, loss and optimizer, opens an interactive session,
    initialises all variables and creates summary writers and a checkpoint
    directory underneath ``model_name``.
    """

    def __init__(self, lr=1e-2, batch_size=64,nz=0,model_name="convLSTM1",context_frames=10,sequence_length=20):
        """Store hyperparameters, build the graph and open the session.

        Args:
            lr: Learning rate passed to the Adam optimizer.
            batch_size: Stored on the instance only; nothing in this class reads it.
            nz: Stored on the instance only; nothing in this class reads it.
            model_name: Key of the architecture in ``myModel``; also used as
                the output directory.
            context_frames: Number of leading frames treated as context.
            sequence_length: Total frames per sequence; the number of
                predicted frames is ``sequence_length - context_frames``.
        """
        # Set hyperparameters
        self.lr = lr
        self.batch_size = batch_size
        self.nz = nz  # was hard-coded to 0, silently discarding the argument
        self.model_name = model_name
        self.context_frames = context_frames
        self.sequence_length = sequence_length
        self.predict_frames = sequence_length - context_frames

        # Build the graph
        self.build()
        # Initialize parameters
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())
        # Summary op
        self.loss_summary = tf.summary.scalar("total_losses", self.total_loss)
        self.summary_op = tf.summary.merge_all()
        self.summary_dir = "./"
        self.output_dir = model_name
        self.checkpoint_dir = self.output_dir + "/checkpoint"
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        self.train_log_file = self.output_dir + "/train_" + datetime.now().strftime("%Y%m%d-%H%M%S")
        self.val_log_file = self.output_dir + "/val_" + datetime.now().strftime("%Y%m%d-%H%M%S")
        self.train_writer = tf.summary.FileWriter(self.train_log_file, self.sess.graph)
        self.val_writer = tf.summary.FileWriter(self.val_log_file, self.sess.graph)

    def myModel(self,model_name="convLSTM1",*args):
        """Build the architecture registered under ``model_name``.

        Args:
            model_name: Key into the table of available architectures.
            *args: Forwarded to the selected builder.

        Returns:
            Whatever the selected builder returns (for "convLSTM1", the
            ``(context, predict)`` tensor pair); also stored in ``self.model``.

        Raises:
            KeyError: If ``model_name`` is not registered.
        """
        # Map names to builders, not to built results: only the requested
        # network is constructed, and an unknown name fails before any graph
        # ops are created.
        model_builders = {
            "convLSTM1": self.convLSTM_network,
        }
        self.model = model_builders[model_name](*args)
        return self.model

    @staticmethod
    def convLSTM_cell(inputs, hidden, nz=16):
        """Run one time step: conv encoder -> ConvLSTM cell -> transposed-conv decoder.

        Args:
            inputs: Frame batch for this step.
            hidden: ConvLSTM state from the previous step, or None to start
                from the cell's zero state.
            nz: Unused; kept for interface compatibility.

        Returns:
            ``(x_hat, hidden)``: the decoded frame and the updated cell state.
        """
        # NOTE(review): the positional arguments to ld.conv_layer are presumably
        # (kernel size, stride, output channels) -- confirm in layer_def.
        print("Inputs shape", inputs.shape)
        conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1",activate="leaky_relu")
        print("Encode_1_shape",conv1.shape)
        conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2",activate="leaky_relu")
        print("Encode 2_shape,",conv2.shape)
        conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3",activate="leaky_relu")
        print("Encode 3_shape, ", conv3.shape)
        y_0 = conv3
        # conv lstm cell
        with tf.variable_scope('conv_lstm', initializer=tf.random_uniform_initializer(-.01, 0.1)):

            # The spatial shape is fixed at 16x16 here, so the encoder output
            # must have that size.
            cell = BasicConvLSTMCell(shape = [16, 16], filter_size = [3, 3], num_features=8)
            if hidden is None:
                hidden = cell.zero_state(y_0, tf.float32)
                print ("hidden zero layer",hidden.shape)
            output, hidden  = cell(y_0, hidden)
            print("output for cell:", output)

        output_shape = output.get_shape().as_list()
        print ("output_shape,",output_shape)

        z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])

        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, "decode_5",activate="leaky_relu")
        print("conv5 shape",conv5)

        conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6",activate="leaky_relu")
        print("conv6 shape",conv6)

        # NOTE(review): an earlier comment here said "set activation to linear",
        # but the activation actually requested is sigmoid.
        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, "decode_7",activate="sigmoid")
        print("x hat shape",x_hat)
        return x_hat, hidden

    def convLSTM_network(self):
        """Unroll the cell over the context and prediction frames.

        Only the first time step reads a ground-truth frame from ``self.x``;
        every later step (context and prediction alike) is fed the previous
        step's output.

        Returns:
            ``(x_hat_context, x_hat_predict)``, each with the batch dimension
            first and the time dimension second.
        """
        network_template = tf.make_template('network', convLSTM.convLSTM_cell) #make the template to share the variables
        # create network
        x_hat_context = []
        x_hat_predict = []
        seq_start = 1  # number of leading steps that consume ground-truth input
        hidden = None
        for i in range(self.context_frames):
            if i < seq_start:
                x_1, hidden = network_template(self.x[:, i, :, :, :], hidden)
            else:
                x_1, hidden = network_template(x_1, hidden)
            x_hat_context.append(x_1)

        for i in range(self.predict_frames):
            x_1, hidden = network_template(x_1, hidden)
            x_hat_predict.append(x_1)

        # pack them all together
        x_hat_context = tf.stack(x_hat_context)
        x_hat_predict = tf.stack(x_hat_predict)
        self.x_hat_context = tf.transpose(x_hat_context, [1, 0, 2, 3, 4]) #change first dim with sec dim
        self.x_hat_predict = tf.transpose(x_hat_predict, [1, 0, 2, 3, 4]) #change first dim with sec dim
        return self.x_hat_context, self.x_hat_predict

    # Build the network and the loss functions
    def build(self):
        """Reset the default graph and build iterators, network, loss, optimizer and saver."""
        tf.reset_default_graph()
        tf.set_random_seed(12345)
        self.train_iterator = make_dataset(type="train")
        self.val_iterator = make_dataset(type="val")
        self.test_iterator = make_dataset(type="test")
        # Create the get_next ops once. Calling iterator.get_next() inside the
        # training loop adds a new op to the graph on every step.
        self.train_next_op = self.train_iterator.get_next()
        self.val_next_op = self.val_iterator.get_next()
        self.x = tf.placeholder(tf.float32, [None,20,64,64,3])
        self.global_step = tf.train.get_or_create_global_step()
        #ARCHITECTURE
        self.x_hat_context_frames,self.x_hat_predict_frames = self.myModel(model_name=self.model_name)
        self.x_hat = tf.concat([self.x_hat_context_frames,self.x_hat_predict_frames], 1)
        print ("x_hat,shape",self.x_hat)

        #Loss calculation: mean squared error on channel 0 only
        self.context_frames_loss = tf.reduce_mean(tf.square(self.x[:,:self.context_frames,:,:,0]  - self.x_hat_context_frames[:,:,:,:,0]))
        self.predict_frames_loss = tf.reduce_mean(tf.square(self.x[:,self.context_frames:,:,:,0]  - self.x_hat_predict_frames[:,:,:,:,0]))
        self.total_loss = self.context_frames_loss + self.predict_frames_loss

        # AdamOptimizer's keyword is `learning_rate`; `lr=` raises TypeError.
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.lr).minimize(self.total_loss,global_step=self.global_step)

        # Build a saver
        self.saver = tf.train.Saver(tf.global_variables())
        return

    # Execute the forward and the backward pass
    def run_single_step(self,global_step):
        """Run one optimisation step and one validation evaluation.

        Args:
            global_step: Step index under which both summaries are logged.

        Returns:
            ``(train_loss, val_loss)``. An entry is ``None`` when the
            corresponding iterator is exhausted.
        """
        # Pre-set both results so an exhausted iterator yields None instead of
        # an UnboundLocalError at the return statement.
        train_losses = None
        val_losses = None
        try:
            train_batch = self.sess.run(self.train_next_op)
            train_summary, _, train_losses = self.sess.run(
                [self.summary_op, self.train_op, self.total_loss],
                feed_dict={self.x: train_batch["images"]})
            self.train_writer.add_summary(train_summary, global_step)
        except tf.errors.OutOfRangeError:
            print("train out of range error")

        try:
            val_batch = self.sess.run(self.val_next_op)
            val_summary, val_losses = self.sess.run(
                [self.summary_op, self.total_loss],
                feed_dict={self.x: val_batch["images"]})
            self.val_writer.add_summary(val_summary, global_step)
        except tf.errors.OutOfRangeError:
            print("val out of range error")

        return train_losses,val_losses


In [15]:
def trainer_and_checkpoint(model_class, learning_rate=1e-2,
            batch_size=64, num_epoch=100, nz=16, log_step=5):
    """Create a model, resume from the latest checkpoint if present, and train.

    Reads the module-level ``num_sample`` to derive the number of iterations
    per epoch (``num_sample // batch_size``).

    Args:
        model_class: Class to instantiate; called as
            ``model_class(lr=..., batch_size=..., nz=...)``.
        learning_rate: Forwarded to the model as ``lr``.
        batch_size: Forwarded to the model and used for the iteration count.
        num_epoch: Number of epochs to run.
        nz: Forwarded to the model.
        log_step: Unused; kept for interface compatibility.

    Returns:
        The trained model instance.
    """
    # The model constructor's keyword is `lr`; passing `learning_rate=` raised
    # a TypeError.
    model = model_class(lr=learning_rate, batch_size=batch_size, nz=nz)
    ckpt = tf.train.get_checkpoint_state(model.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print("Restore from {}".format(ckpt.model_checkpoint_path))
        # Restore into the session that actually trains. Restoring into a
        # separate tf.Session left model.sess with freshly initialised weights.
        # The saver covers all global variables, so the global step is restored
        # too and need not be parsed from the checkpoint filename.
        model.saver.restore(model.sess, ckpt.model_checkpoint_path)
    else:
        print("Initializer from scratch")
    # The input iterators must be initialised on both paths.
    model.sess.run(model.train_iterator.initializer)
    model.sess.run(model.val_iterator.initializer)

    checkpoint_path = os.path.join(model.checkpoint_dir, 'model_arc4.ckpt')
    iterations_per_epoch = num_sample // batch_size
    # Training loop
    for epoch in range(num_epoch):
        # Run an epoch
        for iteration in range(iterations_per_epoch):
            print("iter",iteration)
            global_step = model.sess.run(model.global_step)
            train_losses,val_losses = model.run_single_step(global_step)
            print ("Train_loss: {}; Val_loss{} for global step {}".format(train_losses,val_losses,global_step))
            # Passing global_step appends "-<step>" to the checkpoint name so
            # successive saves are distinguishable.
            model.saver.save(model.sess, checkpoint_path, global_step=model.global_step)

    print('Done!')
    return model
    

In [17]:
# Train a convLSTM with trainer_and_checkpoint's default arguments and keep the returned model.
# NOTE(review): the name says "vae2" but the object is a convLSTM instance -- consider renaming.
model_vae2 = trainer_and_checkpoint(convLSTM)

Inputs shape (?, 64, 64, 3)
conv_layer activation function leaky_relu
Encode_1_shape (?, 32, 32, 8)
conv_layer activation function leaky_relu
Encode 2_shape, (?, 32, 32, 8)
conv_layer activation function leaky_relu
Encode 3_shape,  (?, 16, 16, 8)
hidden zero layer (?, 16, 16, 16)
output for cell: Tensor("network/conv_lstm/BasicConvLSTMCell/mul_2:0", shape=(?, 16, 16, 8), dtype=float32)
output_shape, [None, 16, 16, 8]
output_shape Tensor("network/decode_5_trans_conv/stack:0", shape=(4,), dtype=int32)
conv5 shape Tensor("network/decode_5_trans_conv/decode_5_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network/decode_6_trans_conv/stack:0", shape=(4,), dtype=int32)
conv6 shape Tensor("network/decode_6_trans_conv/decode_6_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network/decode_7_trans_conv/stack:0", shape=(4,), dtype=int32)
x hat shape Tensor("network/decode_7_trans_conv/sigmoid:0", shape=(?, 64, 64, 3), dtype=float32)
Inpu

conv5 shape Tensor("network_8/decode_5_trans_conv/decode_5_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network_8/decode_6_trans_conv/stack:0", shape=(4,), dtype=int32)
conv6 shape Tensor("network_8/decode_6_trans_conv/decode_6_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network_8/decode_7_trans_conv/stack:0", shape=(4,), dtype=int32)
x hat shape Tensor("network_8/decode_7_trans_conv/sigmoid:0", shape=(?, 64, 64, 3), dtype=float32)
Inputs shape (?, 64, 64, 3)
conv_layer activation function leaky_relu
Encode_1_shape (?, 32, 32, 8)
conv_layer activation function leaky_relu
Encode 2_shape, (?, 32, 32, 8)
conv_layer activation function leaky_relu
Encode 3_shape,  (?, 16, 16, 8)
output for cell: Tensor("network_9/conv_lstm/BasicConvLSTMCell/mul_2:0", shape=(?, 16, 16, 8), dtype=float32)
output_shape, [None, 16, 16, 8]
output_shape Tensor("network_9/decode_5_trans_conv/stack:0", shape=(4,), dtype=int32)
conv5 shape Tensor("netw

output_shape Tensor("network_19/decode_6_trans_conv/stack:0", shape=(4,), dtype=int32)
conv6 shape Tensor("network_19/decode_6_trans_conv/decode_6_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network_19/decode_7_trans_conv/stack:0", shape=(4,), dtype=int32)
x hat shape Tensor("network_19/decode_7_trans_conv/sigmoid:0", shape=(?, 64, 64, 3), dtype=float32)
x_hat,shape Tensor("concat:0", shape=(?, 20, 64, 64, 3), dtype=float32)
Initializer from scratch
iter 0
Train_loss: 0.21427994966506958; Val_loss0.14263123273849487 for global step 0
iter 1
Train_loss: 0.1788034737110138; Val_loss0.1106971725821495 for global step 1
iter 2
Train_loss: 0.12229645252227783; Val_loss0.08845406770706177 for global step 2
iter 3
Train_loss: 0.07126347720623016; Val_loss0.0749092847108841 for global step 3
iter 4
Train_loss: 0.07270807027816772; Val_loss0.0580945685505867 for global step 4
iter 5
Train_loss: 0.05784735083580017; Val_loss0.04535266384482384 for global step 5
i



Train_loss: 0.02111382782459259; Val_loss0.028916671872138977 for global step 32
iter 10
Train_loss: 0.020132645964622498; Val_loss0.020434508100152016 for global step 33
iter 11
Train_loss: 0.02447749301791191; Val_loss0.023730896413326263 for global step 34
iter 12
Train_loss: 0.02405542880296707; Val_loss0.016301894560456276 for global step 35
iter 13
Train_loss: 0.03178892284631729; Val_loss0.011570471338927746 for global step 36
iter 14
Train_loss: 0.025819160044193268; Val_loss0.010554560460150242 for global step 37
iter 15
Train_loss: 0.01980629190802574; Val_loss0.009965412318706512 for global step 38
iter 16
Train_loss: 0.015596813522279263; Val_loss0.029409734532237053 for global step 39
iter 17
Train_loss: 0.008390480652451515; Val_loss0.03737672418355942 for global step 40
iter 18
Train_loss: 0.007606159895658493; Val_loss0.038186706602573395 for global step 41
iter 19
Train_loss: 0.024350382387638092; Val_loss0.031375959515571594 for global step 42
iter 20
Train_loss: 0.02

iter 10
Train_loss: 0.023583944886922836; Val_loss0.023498009890317917 for global step 125
iter 11
Train_loss: 0.028142601251602173; Val_loss0.02691151574254036 for global step 126
iter 12
Train_loss: 0.0256429985165596; Val_loss0.016381576657295227 for global step 127
iter 13
Train_loss: 0.024853430688381195; Val_loss0.008420977741479874 for global step 128
iter 14
Train_loss: 0.020954076200723648; Val_loss0.007855052128434181 for global step 129
iter 15
Train_loss: 0.01712506264448166; Val_loss0.007370474748313427 for global step 130
iter 16
Train_loss: 0.01389509066939354; Val_loss0.02488633245229721 for global step 131
iter 17
Train_loss: 0.006827532313764095; Val_loss0.032306693494319916 for global step 132
iter 18
Train_loss: 0.005533850751817226; Val_loss0.03299407288432121 for global step 133
iter 19
Train_loss: 0.024247024208307266; Val_loss0.028097206726670265 for global step 134
iter 20
Train_loss: 0.02031897008419037; Val_loss0.03188082575798035 for global step 135
iter 21


KeyboardInterrupt: 

In [10]:
# Restore the latest checkpoint and run the model on one batch of the test split.
model = convLSTM(lr=1e-4, batch_size=64, nz=16)
ckpt = tf.train.get_checkpoint_state(model.checkpoint_dir)
print("model_checkpoint_dir",model.checkpoint_dir)
# Fail with a clear message instead of an AttributeError on None when nothing
# has been trained yet.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise FileNotFoundError("No checkpoint found in {}".format(model.checkpoint_dir))

# Reuse the session and saver the model already owns rather than opening a
# second session that is never closed.
sess = model.sess
saver = model.saver
saver.restore(sess, ckpt.model_checkpoint_path)
# The global step is a saved variable, so read it back instead of parsing the
# checkpoint filename.
global_step = int(sess.run(model.global_step))
graph = tf.get_default_graph()
loaded_vars = tf.trainable_variables()

# The model builds a test iterator in build(); reuse it instead of adding a
# second test pipeline to the graph.
test_iterator = model.test_iterator
sess.run(test_iterator.initializer)
# Create the get_next op once, outside the loop, so the graph does not grow per batch.
next_test_batch = test_iterator.get_next()
for i in range(1):
    test_batch = sess.run(next_test_batch)
    predict_images,  recon_loss = sess.run([model.x_hat, model.total_loss], feed_dict={model.x: test_batch["images"]})
    # Keep the first sequence of the batch for inspection in the cells below.
    real_img = test_batch["images"][0]
    pred_img = predict_images[0]
    print("recon_loss",recon_loss)
    print("real_image",real_img.shape)
    print("predic image", pred_img.shape)

Inputs shape (?, 64, 64, 3)
conv_layer activation function leaky_relu
Encode_1_shape (?, 32, 32, 8)
conv_layer activation function leaky_relu
Encode 2_shape, (?, 32, 32, 8)
conv_layer activation function leaky_relu
Encode 3_shape,  (?, 16, 16, 8)
hidden zero layer (?, 16, 16, 16)
output for cell: Tensor("network/conv_lstm/BasicConvLSTMCell/mul_2:0", shape=(?, 16, 16, 8), dtype=float32)
output_shape, [None, 16, 16, 8]
output_shape Tensor("network/decode_5_trans_conv/stack:0", shape=(4,), dtype=int32)
conv5 shape Tensor("network/decode_5_trans_conv/decode_5_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network/decode_6_trans_conv/stack:0", shape=(4,), dtype=int32)
conv6 shape Tensor("network/decode_6_trans_conv/decode_6_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network/decode_7_trans_conv/stack:0", shape=(4,), dtype=int32)
x hat shape Tensor("network/decode_7_trans_conv/sigmoid:0", shape=(?, 64, 64, 3), dtype=float32)
Inpu

conv5 shape Tensor("network_8/decode_5_trans_conv/decode_5_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network_8/decode_6_trans_conv/stack:0", shape=(4,), dtype=int32)
conv6 shape Tensor("network_8/decode_6_trans_conv/decode_6_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network_8/decode_7_trans_conv/stack:0", shape=(4,), dtype=int32)
x hat shape Tensor("network_8/decode_7_trans_conv/sigmoid:0", shape=(?, 64, 64, 3), dtype=float32)
Inputs shape (?, 64, 64, 3)
conv_layer activation function leaky_relu
Encode_1_shape (?, 32, 32, 8)
conv_layer activation function leaky_relu
Encode 2_shape, (?, 32, 32, 8)
conv_layer activation function leaky_relu
Encode 3_shape,  (?, 16, 16, 8)
output for cell: Tensor("network_9/conv_lstm/BasicConvLSTMCell/mul_2:0", shape=(?, 16, 16, 8), dtype=float32)
output_shape, [None, 16, 16, 8]
output_shape Tensor("network_9/decode_5_trans_conv/stack:0", shape=(4,), dtype=int32)
conv5 shape Tensor("netw

output_shape Tensor("network_19/decode_5_trans_conv/stack:0", shape=(4,), dtype=int32)
conv5 shape Tensor("network_19/decode_5_trans_conv/decode_5_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network_19/decode_6_trans_conv/stack:0", shape=(4,), dtype=int32)
conv6 shape Tensor("network_19/decode_6_trans_conv/decode_6_transpose_conv:0", shape=(?, 32, 32, 8), dtype=float32)
output_shape Tensor("network_19/decode_7_trans_conv/stack:0", shape=(4,), dtype=int32)
x hat shape Tensor("network_19/decode_7_trans_conv/sigmoid:0", shape=(?, 64, 64, 3), dtype=float32)
x_hat,shape Tensor("concat:0", shape=(?, 20, 64, 64, 3), dtype=float32)




model_checkpoint_dir convLSTM1/checkpoint
INFO:tensorflow:Restoring parameters from convLSTM1/checkpoint/model_arc4.ckpt-98
recon_loss 0.026403949
real_image (20, 64, 64, 3)
predic image (20, 64, 64, 3)


In [11]:
# Rows of channel 0 of the first ground-truth frame, as a list of 1-D arrays.
list(real_img[0, :, :, 0])

[array([0.5560357 , 0.5554696 , 0.5550167 , 0.55508465, 0.55508465,
        0.55490345, 0.5546091 , 0.55422413, 0.55320513, 0.5519144 ,
        0.5500576 , 0.5483366 , 0.54668355, 0.54464555, 0.54299253,
        0.54215467, 0.54122627, 0.54025257, 0.5384863 , 0.5368785 ,
        0.53554255, 0.5287492 , 0.5191933 , 0.512717  , 0.45368317,
        0.4614049 , 0.4616087 , 0.45902723, 0.46871904, 0.47924864,
        0.49399012, 0.49879074, 0.5079164 , 0.5121962 , 0.51584196,
        0.51824224, 0.51969147, 0.52066517, 0.51609105, 0.51217353,
        0.5048821 , 0.48690245, 0.47519532, 0.46878695, 0.46280885,
        0.45766857, 0.452166  , 0.44648227, 0.44734275, 0.44849762,
        0.46081614, 0.44849762, 0.4451236 , 0.43957573, 0.4354318 ,
        0.43414107, 0.4350695 , 0.43713015, 0.44000596, 0.45092055,
        0.4548154 , 0.45830262, 0.4615634 , 0.46244654], dtype=float32),
 array([0.55576396, 0.5558998 , 0.55601305, 0.5557413 , 0.5555602 ,
        0.5554922 , 0.5548582 , 0.5539071 ,

In [20]:
# Rows of channel 0 of the first predicted frame, as a list of 1-D arrays.
list(pred_img[0, :, :, 0])

[array([0.49320534, 0.44745713, 0.5766689 , 0.4641738 , 0.60705465,
        0.46950525, 0.5948406 , 0.4638009 , 0.6054449 , 0.46955097,
        0.59434456, 0.4638989 , 0.6055772 , 0.46947005, 0.594272  ,
        0.46380433, 0.6054064 , 0.4694348 , 0.5942071 , 0.46378076,
        0.60534954, 0.46940926, 0.5941551 , 0.4637759 , 0.6052942 ,
        0.46937263, 0.59406495, 0.46375534, 0.6051829 , 0.46929306,
        0.5938719 , 0.46367675, 0.6050081 , 0.46919897, 0.5936955 ,
        0.46361703, 0.6048424 , 0.4691077 , 0.5935205 , 0.46357316,
        0.60464597, 0.46896937, 0.5932163 , 0.46344322, 0.60436386,
        0.46883845, 0.592953  , 0.46333903, 0.60425395, 0.4687915 ,
        0.59288883, 0.46333337, 0.60419893, 0.46876112, 0.5926205 ,
        0.46312538, 0.6049836 , 0.4690384 , 0.59445316, 0.46433222,
        0.60377157, 0.468295  , 0.5906709 , 0.45809314], dtype=float32),
 array([0.42938492, 0.50645626, 0.5498936 , 0.54659444, 0.5514135 ,
        0.55093247, 0.5663991 , 0.54805183,

In [None]:
# Plot selected ground-truth frames side by side and save the figure to disk.
# NOTE(review): `lon`, `lat`, `input_images_`, `args` and `name` are not defined
# in any visible cell of this notebook; this cell looks pasted from a script and
# will raise NameError until they are supplied.
from matplotlib import gridspec  # was missing: GridSpec is used below

fig = plt.figure(figsize=(18,6))
# NOTE(review): the grid has 10 columns but only 9 time steps are listed in
# `ts`, so the last slot stays empty -- confirm whether that is intended.
gs = gridspec.GridSpec(1, 10)
gs.update(wspace = 0., hspace = 0.)
ts = [0,5,9,10,12,14,16,18,19]
xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
for t in range(len(ts)):
    ax1 = plt.subplot(gs[t])
    # Undo a min-max scaling; presumably these constants are the max and min
    # used when the data was normalised -- TODO confirm.
    input_image = input_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
    plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
    ax1.title.set_text("t = " + str(ts[t]+1))
    plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
    if t == 0:
        # Only the first panel carries ticks and the row label.
        # NOTE(review): 3 tick positions are paired with 5 labels here -- verify.
        plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
        plt.ylabel("Ground Truth", fontsize=10)
plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
plt.clf()