In [None]:
%matplotlib inline
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt
import math

###

def _plot_panel(n_rows, position, i, label, prediction, lower, upper, window_size,
                forget_gate, y_range, split_point):
    """Draw series ``i`` (truth, prediction, +/- sigma band, optional gate) into one subplot.

    ``lower``/``upper`` are the full ``prediction -/+ sigma`` arrays. Row ``i``
    of every array must reshape to ``(window_size,)``. The subplot is cell
    ``position`` of an ``n_rows`` x 2 grid.
    """
    x = np.arange(window_size)
    plt.subplot(n_rows, 2, position)
    truth_line, = plt.plot(x, label[i].reshape([window_size,]), color='b')
    pred_line, = plt.plot(x, prediction[i].reshape([window_size,]), color='r')
    band = plt.fill_between(x,
                            lower[i].reshape([window_size,]),
                            upper[i].reshape([window_size,]),
                            color='blue',
                            alpha=0.2)
    handles = [truth_line, pred_line, band]
    names = ['Ground_truth', 'Prediction', 'Single stddev ']
    # `is not None`: `array != None` is elementwise and would make `if` raise.
    if forget_gate is not None:
        gate_line, = plt.plot(x, forget_gate[i].reshape([window_size,]), color='g')
        handles.append(gate_line)
        names.append('forget_gate')
    plt.legend(handles=handles, labels=names, loc='upper left', fontsize=12)
    plt.axvline(split_point, color='k', linestyle="dashed")  # reference line
    ax = plt.gca()
    ax.tick_params(axis='both', which='major', labelsize=24)
    if y_range is not None:
        ax.set_ylim(y_range)


def plot(label, label_f, prediction, sigma_test, prediction_f, sigma_test_f, window_size, p_value, num_plot = 8,
            y_range = None, y_range_f = None, forget_gate=None, forget_gate_f=None, split_point=168):
    """Plot ground truth vs. prediction (+/- one sigma) for two runs side by side.

    Draws ``num_plot // 2`` rows of two subplots. Row ``i`` shows series ``i``:
    the left panel uses ``label``/``prediction``/``sigma_test`` and the right
    panel uses ``label_f``/``prediction_f``/``sigma_test_f``.

    Args:
        label, label_f: ground truth; ``label[i]`` must reshape to ``(window_size,)``.
        prediction, prediction_f: predicted means, same layout as the labels.
        sigma_test, sigma_test_f: standard deviations; the band
            ``prediction - sigma`` .. ``prediction + sigma`` is shaded.
        window_size: number of time steps per series (length of the x axis).
        p_value: accepted for interface compatibility; currently not drawn.
        num_plot: total number of subplots; ``num_plot // 2`` series are drawn.
        y_range: optional y-limits for the left column.
        y_range_f: optional y-limits for the right column.
        forget_gate, forget_gate_f: optional per-series curves drawn in green
            on the left / right column.
        split_point: x position of the dashed vertical reference line.
    """
    # subplot() requires integer grid dimensions, so use floor division.
    n_rows = num_plot // 2
    plt.figure(figsize=(50, 32))
    lower, upper = prediction - sigma_test, prediction + sigma_test
    lower_f, upper_f = prediction_f - sigma_test_f, prediction_f + sigma_test_f
    for i in range(n_rows):
        # Left column: first run, limited by y_range.
        _plot_panel(n_rows, 2 * i + 1, i, label, prediction, lower, upper,
                    window_size, forget_gate, y_range, split_point)
        # Right column: second run, limited by its own y_range_f.
        _plot_panel(n_rows, 2 * i + 2, i, label_f, prediction_f, lower_f, upper_f,
                    window_size, forget_gate_f, y_range_f, split_point)

def compute_new_v(values):
    """Return the mean of ``values`` plus one.

    For non-negative inputs the result is therefore >= 1 (presumably so it is
    safe to use as a scale divisor, e.g. in ``rescale``).
    """
    shifted_mean = 1 + np.mean(values)
    return shifted_mean
    
def rescale(values, v_old, v_new):
    """Convert ``values`` from scale ``v_old`` to scale ``v_new``.

    Computes ``values * v_old / v_new`` using true (float) division.
    """
    return np.true_divide(values * v_old, v_new)

def compute_forget_index(p_values):
    """Return, for each series, the index of its last change point.

    A change point is a time step whose channel-0 value is exactly 0.

    Args:
        p_values: array indexed as ``p_values[series, time, channel]``.

    Returns:
        Integer array of shape ``[num_series]``. Each entry is the last time
        index ``t`` with ``p_values[series, t, 0] == 0``, or -1 when the
        series has no such step. An empty time axis also yields -1 for every
        series; previously it produced a too-short result.
    """
    is_zero = np.asarray(p_values)[:, :, 0] == 0  # [num_series, window_length]
    num_series, window_length = is_zero.shape
    if window_length == 0:
        return np.full(num_series, -1, dtype=np.intp)
    # argmax on the time-reversed mask finds the first zero counted from the end.
    last_from_end = np.argmax(is_zero[:, ::-1], axis=1)
    has_zero = is_zero.any(axis=1)
    return np.where(has_zero, window_length - 1 - last_from_end, -1)
                
                
        
    



with tf.Session() as sess:
    checkpoint_dir = "../checkpoint/checkpoint_forget/"
    # Alternative (CNN model):
    # checkpoint_dir = "../checkpoint/checkpoint_CNN/"

    # ckpt holds all checkpoint paths plus the latest one.
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    print(ckpt)
    if not (ckpt and ckpt.model_checkpoint_path):
        # Everything below needs the restored graph; fail here with a clear
        # message instead of an IndexError from an empty tf.get_collection(...).
        raise FileNotFoundError("No model can be loaded from %s" % checkpoint_dir)
    # model_checkpoint_path is the latest checkpoint; "<path>.meta" stores its graph.
    saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + ".meta")  # load graph
    saver.restore(sess, ckpt.model_checkpoint_path)  # load weights

    train_op = tf.get_collection('train_op')[0]
    label = tf.get_collection("label")[0]

    miu_train = tf.get_collection("miu_train")[0]
    sigma_train = tf.get_collection("sigma_train")[0]
    RMSE_train = tf.get_collection("RMSE_train")[0]
    ND_train = tf.get_collection("ND_train")[0]

    miu_pred = tf.get_collection("miu_pred")[0]
    sigma_pred = tf.get_collection("sigma_pred")[0]
    RMSE_pred = tf.get_collection("RMSE_pred")[0]
    ND_pred = tf.get_collection("ND_pred")[0]
    hidden_states_all = tf.get_collection("hidden_states_all")[0] # shape: [2, batch_size, hidden_unit, window_size] = [2, 64, 40, 192]
    # forget_gate_all = tf.get_collection("forget_gate_all")[0] # shape: [batch_size, window_size=192, hidden_unit=40]

    def test_step(x_batch, onehot_batch, y_batch, v_batch, batch_size, forget_gate_mask):
        """Run one inference pass; return (RMSE, ND, mu, sigma, hidden_states)."""
        feed_dict = {
            'input_x:0': x_batch,
            'input_onehot:0': onehot_batch,
            'input_y:0': y_batch,
            'input_v:0': v_batch,
            'input_batch:0': batch_size,
            "forget_gate_mask:0": forget_gate_mask,
        }
        # Fetch exactly the five tensors unpacked below. forget_gate_all is not
        # fetched because its collection lookup above is disabled (fetching it
        # raised NameError and would also break the 5-way unpack).
        RMSE_test, ND_test, miu_test, sigma_test, hidden_states_all_test = sess.run(
            [RMSE_pred, ND_pred, miu_pred, sigma_pred, hidden_states_all],
            feed_dict=feed_dict)
        return (RMSE_test, ND_test, miu_test, sigma_test, hidden_states_all_test)

    shift_train_data = np.load("../data/Electricity/elect_pre_train_data.npy")
    shift_train_onehot = np.load("../data/Electricity/elect_train_onehot.npy")
    v_all = np.load("../data/Electricity/elect_train_v.npy")
    shift_train_label = np.load("../data/Electricity/elect_train_label.npy")
    param = np.load("../data/Electricity/elect_train_param.npy")
    index_list = np.load("../data/Electricity/elect_train_index.npy")
    # Alternative dataset (huawei):
    # shift_train_data = np.load("../data/huawei/shift_train_data.npy")
    # shift_train_onehot = np.load("../data/huawei/shift_train_onehot.npy")
    # v_all = np.load("../data/huawei/v.npy")
    # shift_train_label = np.load("../data/huawei/shift_train_label.npy")
    # param = np.load("../data/huawei/param.npy")
    # index_list = np.load("../data/huawei/indexs_list.npy")
    # indexs_pred_list = np.load("../data/huawei/indexs_pred_list.npy")

    window_size = param[7]  # presumably the window length; not used below
    # Evaluate 64 consecutive windows starting at 1058. Other selections tried:
    # "../data/Electricity/elect_train_pred_index.npy", range(498, 562).
    indexs_pred = list(range(1058, 1058 + 64))

    input_x_batch_pred = shift_train_data[indexs_pred]  # fancy indexing -> copy, safe to edit
    input_onehot_batch_pred = shift_train_onehot[indexs_pred]
    input_v_batch_pred = v_all[indexs_pred]
    # np.asfarray was removed in NumPy 2.0; asarray(dtype=float) is equivalent.
    input_y_batch_pred = np.asarray(shift_train_label[indexs_pred], dtype=float)
    batch_size_pred = len(indexs_pred)

    shift_train_pvalue = np.load("../data/Electricity/elect_train_p_value.npy") #[num_window_all, window_size] last 24 p are strictly 0
    shift_train_pvalue_raw = shift_train_pvalue[indexs_pred]  # unshifted p-values of the selected windows
    shift_train_pvalue = np.abs(shift_train_pvalue - 1)  # shift 0 & 1
    hidden_units = 40  # last dimension of the forget-gate mask
    input_pvalue_batch_pred = shift_train_pvalue[indexs_pred]  # [batch_size, window_size]
    input_pvalue_batch_pred = np.repeat(input_pvalue_batch_pred[:, :, np.newaxis],
                                        hidden_units, axis=2)  # [batch_size, window_size, 40]
    # All-ones mask for the baseline run.
    forget_gate_mask_sample = np.ones(
        (batch_size_pred, input_pvalue_batch_pred.shape[1], hidden_units), dtype=np.float32)

    ## WITHOUT FORGET
    RMSE_test, ND_test, miu_test, sigma_test, hidden_states_all_test = test_step(
        x_batch=input_x_batch_pred,
        onehot_batch=input_onehot_batch_pred,
        y_batch=input_y_batch_pred,
        v_batch=input_v_batch_pred,
        batch_size=batch_size_pred,
        forget_gate_mask=forget_gate_mask_sample,
    )

    input_y_batch_pred_copy = np.copy(input_y_batch_pred)  # raw labels
    miu_test_copy = miu_test
    sigma_test_copy = sigma_test

    # Normalize baseline labels/predictions by the per-series scale for plotting.
    # NOTE(review): sigma is divided by sqrt(v) while mu is divided by v — confirm intended.
    input_y_batch_pred = np.true_divide(input_y_batch_pred, input_v_batch_pred)
    miu_test = np.true_divide(miu_test, input_v_batch_pred)
    sigma_test = np.true_divide(sigma_test, np.sqrt(input_v_batch_pred))

    ## WITH FORGET
    # Rescale inputs after each series' last change point to a new scale.
    forget_index = compute_forget_index(input_pvalue_batch_pred)  # [batch_size,]
    new_v = []
    for serie in range(batch_size_pred):
        index_temp = forget_index[serie]
        if index_temp == -1:
            # No change point: keep the original scale.
            new_v.append(input_v_batch_pred[serie])
            continue
        input_x_batch_pred_new = input_x_batch_pred[serie, index_temp + 1:, :]
        input_v_new = compute_new_v(input_x_batch_pred_new * input_v_batch_pred[serie])
        new_v.append(input_v_new)
        input_x_batch_pred[serie, index_temp + 1:, :] = rescale(
            input_x_batch_pred_new, v_old=input_v_batch_pred[serie], v_new=input_v_new)
    new_v = np.expand_dims(np.array(new_v), axis=1)  # [batch_size, 1]

    # Feed the RAW labels, as in the baseline run: input_y_batch_pred was divided
    # by the old scale above and would be inconsistent with v_batch=new_v.
    RMSE_test_f, ND_test_f, miu_test_f, sigma_test_f, hidden_states_all_test_f = test_step(
        x_batch=input_x_batch_pred,
        onehot_batch=input_onehot_batch_pred,
        y_batch=input_y_batch_pred_copy,
        v_batch=new_v,
        batch_size=batch_size_pred,
        forget_gate_mask=input_pvalue_batch_pred,
    )

    plot(label=input_y_batch_pred,
         label_f=input_y_batch_pred_copy,
         prediction=miu_test,
         sigma_test=sigma_test,
         prediction_f=miu_test_f,
         sigma_test_f=sigma_test_f,
         window_size=192,
         p_value=shift_train_pvalue_raw,
         num_plot=16,
         y_range=[-2, 6],
         # y_range_f=[-2, 60],
         forget_gate=None,
         forget_gate_f=None,
         )

