In [1]:
import tensorflow as tf
import utils as util
import numpy as np
import os


def cnn_encoder_layer(data, filter_layer, strides):
    """
    Apply one SAME-padded 2-D convolution followed by a SELU activation.

    :param data: input feature maps. In this encoder the shapes are (per time step)
                 30 * 30 * 3 for the first layer, 30 * 30 * 32 for the second,
                 15 * 15 * 64 for the third and 8 * 8 * 128 for the fourth.
    :param filter_layer: convolution kernel, [height, width, in_channels, out_channels]
    :param strides: 4-element stride tuple passed straight to tf.nn.conv2d
    :return: SELU-activated convolution output; in this encoder 30 * 30 * 32,
             15 * 15 * 64, 8 * 8 * 128 and 4 * 4 * 256 for layers one to four.
    """
    convolved = tf.nn.conv2d(input=data, filters=filter_layer, strides=strides, padding="SAME")
    activated = tf.nn.selu(convolved)
    return activated


def tensor_variable(shape, name):
    """
    Create a trainable weight variable with Glorot/Xavier-uniform initialisation.

    :param shape: full shape of the variable, e.g. a conv kernel [h, w, in_c, out_c]
    :param name: variable name passed to tf.compat.v1.get_variable. get_variable raises
                 if a variable with this name already exists in the current variable
                 scope and reuse is not enabled, so every call site uses a distinct name.
    :return: the variable returned by tf.compat.v1.get_variable
    """
    # VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform") is exactly the
    # Glorot/Xavier uniform initializer.
    # Only one variable is created here. Pre-creating a zero-filled tf.Variable under the
    # same name (as earlier code did) left an unused trainable variable in the graph and
    # forced get_variable to rename the real weight (e.g. "filter1_1").
    initializer = tf.compat.v1.keras.initializers.VarianceScaling(
        scale=1.0, mode="fan_avg", distribution="uniform")
    return tf.compat.v1.get_variable(name, shape=shape, initializer=initializer)


def cnn_encoder(data):
    """
    Encode every time step with four stacked convolution + SELU layers.

    :param data: input tensor; main() feeds [util.step_max, 30, 30, 3], so the time
                 axis acts as the convolution batch axis.
    :return: tuple of the four layer outputs (per step: 30 * 30 * 32, 15 * 15 * 64,
             8 * 8 * 128, 4 * 4 * 256), kept separately for the skip connections
             used downstream.
    """
    # (kernel shape [h, w, in_channels, out_channels], strides) for layers 1..4.
    # Variables are named "filter1".."filter4" in creation order.
    layer_specs = (
        ([3, 3, 3, 32], (1, 1, 1, 1)),     # -> 30 * 30 * 32
        ([3, 3, 32, 64], (1, 2, 2, 1)),    # -> 15 * 15 * 64
        ([2, 2, 64, 128], (1, 2, 2, 1)),   # -> 8 * 8 * 128
        ([2, 2, 128, 256], (1, 2, 2, 1)),  # -> 4 * 4 * 256
    )

    feature_maps = []
    current = data
    for index, (kernel_shape, stride) in enumerate(layer_specs, start=1):
        kernel = tensor_variable(kernel_shape, "filter" + str(index))
        current = cnn_encoder_layer(current, kernel, stride)
        feature_maps.append(current)

    return tuple(feature_maps)


def cnn_lstm_attention_layer(input_data, layer_number):
    """
    Run a ConvLSTM over the time axis and pool its outputs with temporal attention.

    Each step's ConvLSTM output is scored by its inner product with the last step's
    output, divided by util.step_max. The softmax of those scores gives the
    attention weights, and the weighted sum over steps is returned.

    :param input_data: 5-D tensor whose dimensions 2..4 are read as (height, width,
                       channels). Only batch element 0 is used for the attention.
                       NOTE(review): main() feeds [1, util.step_max, H, W, C].
    :param layer_number: suffix that gives each ConvLSTM cell a distinct name
    :return: (attended feature map reshaped to [1, H, W, C],
              attention weights of shape [1, util.step_max])
    """
    height, width, channels = input_data.shape[2], input_data.shape[3], input_data.shape[4]

    cell = tf.contrib.rnn.ConvLSTMCell(
        conv_ndims=2,
        input_shape=[height, width, channels],
        output_channels=input_data.shape[-1],
        kernel_shape=[2, 2],
        use_bias=True,
        skip_connection=False,
        forget_bias=1.0,
        initializers=None,
        name="conv_lstm_cell" + str(layer_number))

    rnn_outputs, _ = tf.compat.v1.nn.dynamic_rnn(cell, input_data, dtype=input_data.dtype)

    # Per-step outputs of the first (and only expected) batch element.
    sequence = rnn_outputs[0]
    last_step = sequence[-1]

    # Scaled inner-product similarity of each step with the last step.
    scores = [
        tf.reduce_sum(input_tensor=tf.multiply(sequence[k], last_step)) / util.step_max
        for k in range(util.step_max)
    ]
    attention_w = tf.reshape(tf.nn.softmax(tf.stack(scores)), [1, util.step_max])

    # Weighted sum over time: [1, step_max] x [step_max, H*W*C] -> [1, H*W*C].
    flattened = tf.reshape(sequence, [util.step_max, -1])
    pooled = tf.matmul(attention_w, flattened)
    attended = tf.reshape(pooled, [1, height, width, channels])

    return attended, attention_w


def cnn_decoder_layer(conv_lstm_out_c, filter, output_shape, strides):
    """
    Apply one SAME-padded transposed convolution followed by a SELU activation.

    :param conv_lstm_out_c: input feature maps (in this model, an attention-pooled
                            ConvLSTM output, possibly concatenated with the previous
                            decoder output)
    :param filter: transposed-conv kernel, [height, width, out_channels, in_channels]
                   (name kept for caller compatibility even though it shadows the builtin)
    :param output_shape: explicit output shape, e.g. [1, 8, 8, 128]
    :param strides: 4-element stride tuple passed straight to tf.nn.conv2d_transpose
    :return: SELU-activated transposed-convolution output of shape output_shape
    """
    kernel = filter
    upsampled = tf.nn.conv2d_transpose(
        input=conv_lstm_out_c,
        filters=kernel,
        output_shape=output_shape,
        strides=strides,
        padding="SAME",
    )
    return tf.nn.selu(upsampled)


def cnn_decoder(lstm1_out, lstm2_out, lstm3_out, lstm4_out):
    """
    Decode the attention-pooled ConvLSTM features back to a 30 * 30 * 3 matrix.

    Starting from the deepest feature map (lstm4_out), each stage upsamples with a
    transposed convolution. The result is then concatenated on axis 3 with the
    matching skip input (lstm3_out, lstm2_out, lstm1_out in that order). A final
    stride-1 transposed convolution produces the reconstruction.

    :return: reconstructed tensor of shape [1, 30, 30, 3]
    """
    # (variable name, kernel [h, w, out_c, in_c], output shape, strides) per upsampling stage.
    stages = (
        ("d_filter4", [2, 2, 128, 256], [1, 8, 8, 128], (1, 2, 2, 1)),
        ("d_filter3", [2, 2, 64, 256], [1, 15, 15, 64], (1, 2, 2, 1)),
        ("d_filter2", [3, 3, 32, 128], [1, 30, 30, 32], (1, 2, 2, 1)),
    )
    skip_inputs = (lstm3_out, lstm2_out, lstm1_out)

    hidden = lstm4_out
    for (var_name, kernel_shape, out_shape, stride), skip in zip(stages, skip_inputs):
        kernel = tensor_variable(kernel_shape, var_name)
        upsampled = cnn_decoder_layer(hidden, kernel, out_shape, stride)
        hidden = tf.concat([upsampled, skip], axis=3)

    final_kernel = tensor_variable([3, 3, 3, 64], "d_filter1")
    return cnn_decoder_layer(hidden, final_kernel, [1, 30, 30, 3], (1, 1, 1, 1))


def main():
    """
    Train the encoder / ConvLSTM-attention / decoder model and save test reconstructions.

    Reads train.npy from util.train_data_path and test.npy from util.test_data_path.
    Each indexed entry is fed as one [util.step_max, 30, 30, 3] window. The model is
    trained with Adam to reconstruct the last step of each window, in a single pass over
    util.train_start_id .. util.train_end_id - 1. It then reconstructs every test window
    and saves the stacked results to
    <util.reconstructed_data_path>/test_reconstructed.npy.
    """
    step_max = util.step_max

    # Read dataset from file
    train_matrices = np.load(os.path.join(util.train_data_path, "train.npy"))

    data_input = tf.compat.v1.placeholder(tf.float32, [step_max, 30, 30, 3])

    # cnn encoder: each time step is encoded independently (time acts as the batch axis)
    conv1_out, conv2_out, conv3_out, conv4_out = cnn_encoder(data_input)

    # Regroup per-step feature maps into [batch, time, H, W, C] sequences for the ConvLSTM.
    # The time axis must be util.step_max (not a literal 5) to agree with the placeholder
    # above and with the attention loop in cnn_lstm_attention_layer.
    conv1_out = tf.reshape(conv1_out, [-1, step_max, 30, 30, 32])
    conv2_out = tf.reshape(conv2_out, [-1, step_max, 15, 15, 64])
    conv3_out = tf.reshape(conv3_out, [-1, step_max, 8, 8, 128])
    conv4_out = tf.reshape(conv4_out, [-1, step_max, 4, 4, 256])

    # lstm with attention
    conv1_lstm_attention_out, atten_weight_1 = cnn_lstm_attention_layer(conv1_out, 1)
    conv2_lstm_attention_out, atten_weight_2 = cnn_lstm_attention_layer(conv2_out, 2)
    conv3_lstm_attention_out, atten_weight_3 = cnn_lstm_attention_layer(conv3_out, 3)
    conv4_lstm_attention_out, atten_weight_4 = cnn_lstm_attention_layer(conv4_out, 4)

    # cnn decoder
    deconv_out = cnn_decoder(conv1_lstm_attention_out, conv2_lstm_attention_out,
                             conv3_lstm_attention_out, conv4_lstm_attention_out)

    # loss function: reconstruction error of the last step in the window
    loss = tf.reduce_mean(input_tensor=tf.square(data_input[-1] - deconv_out))
    train_op = tf.compat.v1.train.AdamOptimizer(learning_rate=util.learning_rate).minimize(loss)
    init = tf.compat.v1.global_variables_initializer()

    # The context manager closes the session (releasing its resources) on exit,
    # including when an exception is raised.
    with tf.compat.v1.Session() as sess:
        sess.run(init)

        # training
        for idx in range(util.train_start_id, util.train_end_id):
            window = train_matrices[idx - util.train_start_id]
            feed_dict = {data_input: np.asarray(window)}
            _, loss_value = sess.run([train_op, loss], feed_dict)
            print("mse of last train data: " + str(loss_value))

        # test: read the data from the test file and reconstruct every window
        test_matrices = np.load(os.path.join(util.test_data_path, "test.npy"))
        result_all = []
        for idx in range(util.test_start_id, util.test_end_id):
            window = test_matrices[idx - util.test_start_id]
            feed_dict = {data_input: np.asarray(window)}
            result, loss_value = sess.run([deconv_out, loss], feed_dict)
            result_all.append(result)
            print("mse of last test data: " + str(loss_value))

    # Write the reconstructed matrices to the output directory. os.path.join keeps
    # the file inside the directory even if the configured path has no trailing separator.
    reconstructed_dir = util.reconstructed_data_path
    os.makedirs(reconstructed_dir, exist_ok=True)
    reconstructed_path = os.path.join(reconstructed_dir, "test_reconstructed.npy")

    result_all = np.asarray(result_all).reshape((-1, 30, 30, 3))
    print(result_all.shape)
    np.save(reconstructed_path, result_all)


if __name__ == "__main__":
    # Script entry point: build the model, train, reconstruct the test set, save results.
    main()


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
mse of last train data: 0.00056191476
mse of last train data: 0.0008397814
mse of last train data: 0.0009788115
mse of last train data: 0.0015275143
mse of last train data: 0.0018960452
mse of last train data: 0.0015373339
mse of last train data: 0.0012186245
mse of last train data: 0.0009643246
mse of last train data: 0.00080712256
mse of last train data: 0.0006685078
mse of last train data: 0.00055669027
mse of last 

mse of last train data: 0.0013464914
mse of last train data: 0.0012574107
mse of last train data: 0.0011973963
mse of last train data: 0.0011483922
mse of last train data: 0.0010861779
mse of last train data: 0.0011279699
mse of last train data: 0.001181956
mse of last train data: 0.0011827709
mse of last train data: 0.0012015909
mse of last train data: 0.0012706444
mse of last train data: 0.0012969548
mse of last train data: 0.0013360518
mse of last train data: 0.0013549714
mse of last train data: 0.001344365
mse of last train data: 0.0013358138
mse of last train data: 0.0012971199
mse of last train data: 0.0012400298
mse of last train data: 0.0011351941
mse of last train data: 0.0011002468
mse of last train data: 0.0011178211
mse of last train data: 0.0011359766
mse of last train data: 0.001113067
mse of last train data: 0.0010521347
mse of last train data: 0.0010077746
mse of last train data: 0.0009265113
mse of last train data: 0.00095399655
mse of last train data: 0.0009676965
mse

mse of last train data: 0.0016329982
mse of last train data: 0.0015298986
mse of last train data: 0.0015420592
mse of last train data: 0.001491368
mse of last train data: 0.0014125028
mse of last train data: 0.0013982751
mse of last train data: 0.0014319657
mse of last train data: 0.0014147809
mse of last train data: 0.0013569839
mse of last train data: 0.0012757038
mse of last train data: 0.0011809135
mse of last train data: 0.0011647923
mse of last train data: 0.001151103
mse of last train data: 0.001145226
mse of last train data: 0.0011862615
mse of last train data: 0.001189354
mse of last train data: 0.001171359
mse of last train data: 0.0011854062
mse of last train data: 0.0011679702
mse of last train data: 0.0011461786
mse of last train data: 0.0011005491
mse of last train data: 0.0010598543
mse of last train data: 0.0010017945
mse of last train data: 0.0009647608
mse of last train data: 0.0009159361
mse of last train data: 0.0008956352
mse of last train data: 0.00088806055
mse o

mse of last train data: 0.0009541947
mse of last train data: 0.0010108207
mse of last train data: 0.0010364436
mse of last train data: 0.0009894947
mse of last train data: 0.00097643986
mse of last train data: 0.000925609
mse of last train data: 0.0009026852
mse of last train data: 0.00093772635
mse of last train data: 0.0009970642
mse of last train data: 0.0010504007
mse of last train data: 0.0011059847
mse of last train data: 0.0011673279
mse of last train data: 0.0011629892
mse of last train data: 0.0011685679
mse of last train data: 0.0012121707
mse of last train data: 0.0011760428
mse of last train data: 0.0011301434
mse of last train data: 0.0011292599
mse of last train data: 0.0012091325
mse of last train data: 0.0012273844
mse of last train data: 0.0012395179
mse of last train data: 0.0012739047
mse of last train data: 0.0013571507
mse of last train data: 0.0013433759
mse of last train data: 0.0013559854
mse of last train data: 0.0013994459
mse of last train data: 0.0014797584


mse of last test data: 0.04082334
mse of last test data: 0.040083185
mse of last test data: 0.03985819
mse of last test data: 0.039754167
mse of last test data: 0.039474037
mse of last test data: 0.039507117
mse of last test data: 0.039683733
mse of last test data: 0.039667007
mse of last test data: 0.03954234
mse of last test data: 0.03914306
mse of last test data: 0.0385736
mse of last test data: 0.038148172
mse of last test data: 0.037649598
mse of last test data: 0.037061617
mse of last test data: 0.036700785
mse of last test data: 0.036190856
mse of last test data: 0.035422195
mse of last test data: 0.034796197
mse of last test data: 0.034407858
mse of last test data: 0.034019623
mse of last test data: 0.033275053
mse of last test data: 0.03250243
mse of last test data: 0.031793736
mse of last test data: 0.03122516
mse of last test data: 0.030735435
mse of last test data: 0.030259106
mse of last test data: 0.030017732
mse of last test data: 0.02974167
mse of last test data: 0.0290

mse of last test data: 0.035031658
mse of last test data: 0.035490207
mse of last test data: 0.03572595
mse of last test data: 0.03601049
mse of last test data: 0.03597831
mse of last test data: 0.036039367
mse of last test data: 0.03608627
mse of last test data: 0.03606685
mse of last test data: 0.03624761
mse of last test data: 0.036428586
mse of last test data: 0.03631103
mse of last test data: 0.036340993
mse of last test data: 0.036536086
mse of last test data: 0.036383446
mse of last test data: 0.03621247
mse of last test data: 0.03625416
mse of last test data: 0.036143474
mse of last test data: 0.036067598
mse of last test data: 0.036154605
mse of last test data: 0.03604631
mse of last test data: 0.0359825
mse of last test data: 0.036142346
mse of last test data: 0.036077425
mse of last test data: 0.03606107
mse of last test data: 0.036258876
mse of last test data: 0.03632675
mse of last test data: 0.036120854
mse of last test data: 0.036069844
mse of last test data: 0.036016762

mse of last test data: 0.054605152
mse of last test data: 0.054154
mse of last test data: 0.05385637
mse of last test data: 0.0536435
mse of last test data: 0.05328147
mse of last test data: 0.053027265
mse of last test data: 0.053042434
mse of last test data: 0.05267731
mse of last test data: 0.052673582
mse of last test data: 0.05239616
mse of last test data: 0.051708102
mse of last test data: 0.051275477
mse of last test data: 0.05111995
mse of last test data: 0.050987016
mse of last test data: 0.050700914
mse of last test data: 0.050172016
mse of last test data: 0.049366377
mse of last test data: 0.048903815
mse of last test data: 0.04788856
mse of last test data: 0.0470797
mse of last test data: 0.046200998
mse of last test data: 0.045548897
mse of last test data: 0.044755854
mse of last test data: 0.04399403
mse of last test data: 0.04330141
mse of last test data: 0.042683505
mse of last test data: 0.04236841
mse of last test data: 0.04186626
mse of last test data: 0.0416881
mse 

mse of last test data: 0.07809034
mse of last test data: 0.07847417
mse of last test data: 0.079089314
mse of last test data: 0.07933486
mse of last test data: 0.079938
mse of last test data: 0.08079829
mse of last test data: 0.081314266
mse of last test data: 0.08216548
mse of last test data: 0.0826575
mse of last test data: 0.08339308
mse of last test data: 0.083519086
mse of last test data: 0.0840565
mse of last test data: 0.084472895
mse of last test data: 0.08464473
mse of last test data: 0.08453613
mse of last test data: 0.083544716
mse of last test data: 0.082723804
mse of last test data: 0.08207079
mse of last test data: 0.08183347
mse of last test data: 0.0817586
mse of last test data: 0.081749596
mse of last test data: 0.08135106
mse of last test data: 0.08085264
mse of last test data: 0.08009475
mse of last test data: 0.07935655
mse of last test data: 0.07896202
mse of last test data: 0.07839527
mse of last test data: 0.07789573
mse of last test data: 0.077722974
mse of last

mse of last test data: 0.061568096
mse of last test data: 0.062518425
mse of last test data: 0.06325057
mse of last test data: 0.06402895
mse of last test data: 0.06497191
mse of last test data: 0.06579888
mse of last test data: 0.067111544
mse of last test data: 0.068179786
mse of last test data: 0.06903188
mse of last test data: 0.07017574
mse of last test data: 0.07141715
mse of last test data: 0.07240259
mse of last test data: 0.07312126
mse of last test data: 0.07356385
mse of last test data: 0.07393713
mse of last test data: 0.07472873
mse of last test data: 0.07574809
mse of last test data: 0.07669301
mse of last test data: 0.0775595
mse of last test data: 0.078188
mse of last test data: 0.07941976
mse of last test data: 0.080385394
mse of last test data: 0.08167663
mse of last test data: 0.08269441
mse of last test data: 0.08372783
mse of last test data: 0.084634386
mse of last test data: 0.08485869
mse of last test data: 0.085374646
mse of last test data: 0.085740976
mse of la