In [1]:
import torch
device = torch.device("cpu")

import sys
import os

# Make the notebook's working directory and its `src/` sub-folder importable so
# that the project modules (`utils`, `layers`, `models`) can be imported below.
# Entries are only added when missing, so re-running this cell does not keep
# growing `sys.path`.
_nb_dir = os.path.abspath('')
for _extra_path in (_nb_dir, os.path.join(_nb_dir, 'src')):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

from utils import *
from layers import *
from models import *

import pandas as pd
import numpy as np

import seaborn as sns
from ing_theme_matplotlib import mpl_style
import matplotlib as mpl
import matplotlib.pyplot as plt

# Read the three prepared CSV tables: train split, test split and the merged table.
df_train_total, df_test_total, df_merged = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../data/df_train_total.csv",
        "../data/df_test_total.csv",
        "../data/df_merged.csv",
    )
)

# Model-ready arrays for each split, unpacked in the order returned by
# `generate_ts_data`.
(train_conti_input, train_cate_input,
 train_future_input, train_label) = generate_ts_data(df_train_total, df_merged)
(test_conti_input, test_cate_input,
 test_future_input, test_label) = generate_ts_data(df_test_total, df_merged)

# Evaluation arrays; later cells wrap them as eval_conti, eval_cate,
# eval_future, eval_label and eval_past_label, in that order.
eval_a, eval_b, eval_c, eval_d, eval_e = generate_eval_ts(
    df_test_total,
    df_merged,
    input_seq_len=48,
    tau=12,
)

In [2]:
# Wrap the evaluation arrays as tensors: float32 for eval_a / eval_d / eval_e,
# int64 for eval_b / eval_c.
eval_conti, eval_label, eval_past_label = map(torch.FloatTensor, (eval_a, eval_d, eval_e))
eval_cate, eval_future = map(torch.LongTensor, (eval_b, eval_c))

In [None]:
# Hyper-parameter values that recur across the three baseline models below.
_N_VARS = 16
_CATE_DIMS = (16, 32, 24)
_SEQ_LEN = 48
_TAU = 12
_N_TARGETS = 1
_QUANTILES = (0.1, 0.3, 0.5, 0.7, 0.9)
_DROPOUT = 0.1

# Temporal Fusion Transformer baseline.
tft = TemporalFusionTransformer(
    d_model=30,
    d_embedding=5,
    cate_dims=list(_CATE_DIMS),
    num_cv=_N_VARS,
    seq_len=_SEQ_LEN,
    num_targets=_N_TARGETS,
    tau=_TAU,
    quantile=list(_QUANTILES),
    dr=_DROPOUT,
    device=device,
)

# DeepAR baseline.
deepar = DeepAR(
    d_input=_N_VARS,
    d_embedding=3,
    n_embedding=list(_CATE_DIMS),
    d_model=30,
    num_targets=_N_TARGETS,
    n_layers=3,
    dr=_DROPOUT,
)

# MQ-RNN baseline; it receives the number of quantiles rather than the levels.
mqrnn = MQRnn(
    d_input=_N_VARS,
    d_embedding=1,
    n_embedding=list(_CATE_DIMS),
    d_model=5,
    tau=_TAU,
    num_targets=_N_TARGETS,
    num_quantiles=len(_QUANTILES),
    n_layers=3,
    dr=_DROPOUT,
)


# 16 x 16 binary adjacency matrix handed (after row normalisation) to HSDSTM as `adj`.
# Every row contains a 1 on the diagonal, so no row sum below is zero.
# NOTE(review): which physical variable each index stands for is not visible here — confirm against the data pipeline.
adj = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
                    [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
                    [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]]
                ).float()

# Row-normalise: divide each row by its own sum so that every row adds up to 1.
norm_adj = adj/adj.sum(dim=-1).unsqueeze(-1)
norm_adj = norm_adj.to(device)



# STA-LSTM baseline, built with positional arguments.
# NOTE(review): the values coincide with seq_len=48, 16 input variables, tau=12 and
# 5 quantiles used by the other models — presumably passed in that order; confirm
# against the STALSTM signature.
ding = STALSTM(48, 16, 12, 5)

# HSDSTM baseline; receives the row-normalised adjacency matrix `norm_adj`.
deng = HSDSTM(
    adj=norm_adj,
    input_size=16,
    seq_len=48,
    num_channels=[16, 16],
    node_dim=1,
    dropout=0.1,
    num_levels=3,
    tau=12,
    num_quantiles=5
    )


# Restore the trained weights of every baseline from ../assets, mapped onto the CPU.
for _net, _ckpt in (
    (mqrnn, "../assets/MQRnn.pth"),
    (deepar, "../assets/DeepAR.pth"),
    (tft, "../assets/TFT.pth"),
    (ding, '../assets/STALSTM.pth'),
    (deng, '../assets/HSDSTM.pth'),
):
    _net.load_state_dict(torch.load(_ckpt, map_location="cpu"))

In [4]:
# Build the evaluation input for HSDSTM (`deng`); the two other return values
# of the helper are not needed here.
test_input_for_deng, _, _ = generate_eval_ts_for_deng(df_test_total, df_merged)

deng.eval()
# Inference only: no_grad avoids recording an autograd graph for the whole test set.
with torch.no_grad():
    deng_output = deng(torch.tensor(test_input_for_deng))

In [5]:
# Build the evaluation input for STA-LSTM (`ding`); the two other return values
# of the helper are not needed here.
test_input_for_ding, _, _ = generate_eval_ts_for_ding(df_test_total, df_merged)

ding.eval()
# Inference only: no_grad avoids recording an autograd graph for the whole test set.
# `alpha` is plotted further below (Figure D.9); `beta` is kept for completeness.
with torch.no_grad():
    ding_output, alpha, beta = ding(torch.tensor(test_input_for_ding))

In [6]:
# 16 x 16 binary matrix passed as `spatial_structure` to InstaTran and to the
# "with TFT decoder" variant below.
# NOTE(review): presumably entry (i, j) = 1 allows variable i to attend to variable j
# in the spatial attention — confirm against the model implementation.
sps = torch.tensor([ [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
                    [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
                ).float()

In [None]:
def _shared_transformer_kwargs():
    """Return a fresh dict of the keyword arguments common to every variant below."""
    return dict(
        cate_dims=[16, 32, 24],
        num_cv=16,
        seq_len=48,
        num_targets=1,
        tau=12,
        quantile=[0.1, 0.3, 0.5, 0.7, 0.9],
        dr=0.1,
        device=device,
    )


# Proposed model: uses the spatial structure `sps`.
instatran = InstaTran(
    d_model=10, d_embedding=3, spatial_structure=sps, **_shared_transformer_kwargs()
)

# Ablation: no spatial structure.
instatran_wo_sps = SpatialTemporalTransformer(
    d_model=30, d_embedding=5, spatial_structure=None, **_shared_transformer_kwargs()
)

# Ablation: parallel-attention variant, no spatial structure.
instatran_parallel = SpatialTemporalParallelTransformer(
    d_model=30, d_embedding=5, spatial_structure=None, **_shared_transformer_kwargs()
)

# Ablation: "without M_S" variant, no spatial structure.
instatran_wo_M_S = SpatialTemporalTransformer2(
    d_model=10, d_embedding=3, spatial_structure=None, **_shared_transformer_kwargs()
)

# Ablation: "with TFT decoder" variant, with the spatial structure `sps`.
instatran_w_tft_decoder = SpatialTemporalTransformer(
    d_model=30, d_embedding=5, spatial_structure=sps, **_shared_transformer_kwargs()
)

In [None]:
# Restore the trained weights of InstaTran and its ablation variants (CPU-mapped).
for _net, _ckpt in (
    (instatran, "../assets/InstaTran.pth"),
    (instatran_wo_sps, "../assets/InstaTran_wo_sps.pth"),
    (instatran_parallel, "../assets/InstaTran_parallel.pth"),
    (instatran_wo_M_S, "../assets/InstaTran_wo_M_S.pth"),
    (instatran_w_tft_decoder, "../assets/InstaTran_w_tft_decoder.pth"),
):
    _net.load_state_dict(torch.load(_ckpt, map_location='cpu'))

In [9]:
# Global figure settings: lower the resolution first, then apply the light ING theme.
mpl.rcParams.update({"figure.dpi": 60})
mpl_style(dark=False)

In [10]:
# Evaluation entries 1500..1649 are used as the sample for the interpretation figures.
_sample = slice(1500, 1650)
eval_conti_sample = torch.FloatTensor(eval_a[_sample])
eval_cate_sample = torch.LongTensor(eval_b[_sample])
eval_future_sample = torch.LongTensor(eval_c[_sample])
eval_label_sample = torch.FloatTensor(eval_d[_sample])
eval_past_label_sample = torch.FloatTensor(eval_e[_sample])

instatran.eval()
instatran_wo_sps.eval()

# Inference only: no_grad keeps PyTorch from recording autograd graphs that
# are never used (the outputs are only plotted).
with torch.no_grad():
    # InstaTran with the spatial structure.
    (output, ssa_weight1, ssa_weight2, tsa_weight,
     dec_weights, fi1, fi2) = instatran(eval_conti_sample, eval_cate_sample, eval_future_sample)

    # Same forward pass for the variant without the spatial structure.
    (output_no_sps, ssa_weight1_no_sps, ssa_weight2_no_sps, tsa_weight_no_sps,
     dec_weights_no_sps, fi1_no_sps, fi2_no_sps) = instatran_wo_sps(eval_conti_sample, eval_cate_sample, eval_future_sample)

#### Figure 4 (a) and (b)

In [None]:
# Dry
# One heat map per model (with, then without the spatial structure), both taken
# at the same position [10, 15] of the second spatial-attention weight tensor.
for _weights in (ssa_weight2, ssa_weight2_no_sps):
    plt.matshow(_weights[10, 15, ...].detach().numpy(), cmap='Reds')
    plt.colorbar()
    plt.xlabel("Variable index")
    plt.ylabel("Variable index")

#### Figure 4 (c) and (d)

In [None]:
# Rainy
# One heat map per model (with, then without the spatial structure), both taken
# at the same position [50, 0] of the second spatial-attention weight tensor.
for _weights in (ssa_weight2, ssa_weight2_no_sps):
    plt.matshow(_weights[50, 0, ...].detach().numpy(), cmap='Reds')
    plt.colorbar()
    plt.xlabel("Variable index")
    plt.ylabel("Variable index")

In [13]:
# Run only the first stages of TFT on the sample to obtain the output of its
# first variable-selection network (`tft_vsn_output`), used as TFT's variable
# importance in Figure 5.
tft.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    confe_output = tft.confe(eval_conti_sample)
    catfe_output = tft.catfe(eval_cate_sample)
    # Join both embeddings along the second-to-last axis.
    obs_feature = torch.cat([confe_output, catfe_output], dim=-2)
    x1, tft_vsn_output = tft.vsn1(obs_feature)

In [14]:
# InstaTran forward pass on the sample; `fi1` is plotted as its variable importance below.
instatran.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    (instatran_output, ssa_weight1, ssa_weight2, tsa_weight,
     dec_weights, fi1, fi2) = instatran(eval_conti_sample, eval_cate_sample, eval_future_sample)

#### Figure 5 (a)

In [None]:
# Every 48th sample entry is taken and flattened into one series; the first
# 30 points are dropped from all three curves.
# `ax` now holds the Axes itself (it used to hold the return value of `.set()`),
# and the other curves are drawn explicitly on that same Axes.
ax = sns.lineplot(eval_conti_sample.cpu()[::48, :, 0].squeeze().reshape(-1)[30:], label=r"Observation of $P_1$")
ax.set(xlabel="Time points")
sns.lineplot(fi1.detach().cpu()[::48, :, 0, 0].reshape(-1)[30:], linestyle='--', label=r"Importance of $P_1$ (InstaTran)", ax=ax)
sns.lineplot(tft_vsn_output.detach().cpu()[::48, :, 0, 0].reshape(-1)[30:], linestyle=':', label=r"Importance of $P_1$ (TFT)", ax=ax)
plt.show()

In [None]:
feature_idx = 8
n_windows, window_len = 150, 48            # sample entries and time steps per entry
n_timeline = n_windows + window_len - 1    # 197 columns: entry i covers columns i .. i+47


def _window_importance_matrix(weights):
    """Lay out the importance of `feature_idx` on a shared time line.

    Row i holds entry i's `window_len` values in columns i .. i+window_len-1;
    every other cell stays 0. The tensor is converted to NumPy once instead of
    once per row.
    """
    values = weights.detach().cpu().numpy()[:n_windows, :, feature_idx, 0]
    mat = np.zeros((n_windows, n_timeline))
    for i in range(n_windows):
        mat[i, i:i + window_len] = values[i]
    return mat


def _mean_over_windows(mat):
    """Column-wise mean that ignores cells equal to 0.0 (treated as 'not covered')."""
    return np.nanmean(np.where(mat == 0.0, np.nan, mat), axis=0)


importance_mat = _window_importance_matrix(fi1)
tft_importance_mat = _window_importance_matrix(tft_vsn_output)

# `ax` holds the Axes itself (not the return value of `.set()`), and the
# importance curves are drawn explicitly on it.
ax = sns.lineplot(eval_conti_sample.cpu()[::window_len, :, feature_idx].squeeze().reshape(-1)[30:], label=r"Observation of OF")
ax.set(xlabel="Time points")
sns.lineplot(_mean_over_windows(importance_mat)[30:-5], linestyle='--', label=r"Importance of OF (InstaTran)", ax=ax)
sns.lineplot(_mean_over_windows(tft_importance_mat)[30:-5], linestyle=':', label=r"Importance of OF (TFT)", ax=ax)
plt.show()

In [17]:
# Full evaluation set as tensors (float32 inputs/labels, int64 categorical inputs).
eval_conti = torch.FloatTensor(eval_a)
eval_cate = torch.LongTensor(eval_b)
eval_future = torch.LongTensor(eval_c)
eval_label = torch.FloatTensor(eval_d)
eval_past_label = torch.FloatTensor(eval_e)

# InstaTran forward pass on the full evaluation set. eval() is set explicitly so
# the cell does not depend on an earlier cell, and no_grad avoids building an
# autograd graph for the whole set.
instatran.eval()
with torch.no_grad():
    (instatran_output, ssa_weight1, ssa_weight2, tsa_weight,
     dec_weights, fi1, fi2) = instatran(eval_conti, eval_cate, eval_future)

#### Figure 6 (a) and (b)

In [None]:
mpl_style(dark=False)

# Font sizes used for the text elements of the figures.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 10, 14, 18
plt.rc('font', size=SMALL_SIZE)
plt.rc('axes', titlesize=MEDIUM_SIZE, labelsize=MEDIUM_SIZE)
for _tick_group in ('xtick', 'ytick'):
    plt.rc(_tick_group, labelsize=MEDIUM_SIZE)
plt.rc('legend', fontsize=SMALL_SIZE)
plt.rc('figure', titlesize=BIGGER_SIZE)

# Median over the first axis of decoder-attention row 47 + k, for k = 1..6.
_dec_w = dec_weights.detach().numpy()
_median_by_k = {k: np.quantile(_dec_w[:, 47 + k], 0.5, axis=0) for k in range(1, 7)}

g = sns.lineplot(_median_by_k[1], label=r"$k = 1$")
g.set_ylim(0, 0.029)
g.set_xticks([0, 12, 24, 36, 47, 59], ["-47", "-35", "-23", "-11", "0", "12"])
g.set_xlabel("Time points")
g.set_ylabel("Median of attention weights")
for k in range(2, 7):
    sns.lineplot(_median_by_k[k], label=rf"$k = {k}$")
# Dashed vertical line at x = 47 (the tick labelled "0").
g.axvline(47, linestyle='--', linewidth=2, color='k')

# Double-headed arrow spanning 12 time points, with a two-line caption.
xstart = 24
ystart = 0.026
g.annotate(
    "",
    xy=(xstart, ystart),
    xytext=(xstart+12, ystart),
    va="center",
    ha="center",
    arrowprops={"color": 'black', "arrowstyle": "<->"},
)
g.annotate("Half-daily interval", xy=(xstart-4, ystart+0.002), xytext=(xstart-0.7, ystart+0.002), color='black')
g.annotate("(12 hours)", xy=(xstart, ystart+0.0005), xytext=(xstart+2.4, ystart+0.0005), color='black')
plt.show()

In [None]:
# Median over the first axis of decoder-attention row 47 + k, for k = 7..12.
_dec_w = dec_weights.detach().numpy()
_median_by_k = {k: np.quantile(_dec_w[:, 47 + k, :], 0.5, axis=0) for k in range(7, 13)}

g = sns.lineplot(_median_by_k[7], label=r"$k = 7$")
g.set_ylim(0, 0.029)
g.set_xticks([0, 12, 24, 36, 47, 59], ["-47", "-35", "-23", "-11", "0", "12"])
g.set_xlabel("Time points")
g.set_ylabel("Median of attention weights")
for k in range(8, 13):
    sns.lineplot(_median_by_k[k], label=rf"$k = {k}$")
# Dashed vertical line at x = 47 (the tick labelled "0").
g.axvline(47, linestyle='--', linewidth=2, color='k')

# Double-headed arrow spanning 12 time points, with a two-line caption.
xstart = 28
ystart = 0.026
g.annotate(
    "",
    xy=(xstart, ystart),
    xytext=(xstart+12, ystart),
    va="center",
    ha="center",
    arrowprops={"color": 'black', "arrowstyle": "<->"},
)
g.annotate("Half-daily interval", xy=(xstart-4, ystart+0.002), xytext=(xstart-0.7, ystart+0.002), color='black')
g.annotate("(12 hours)", xy=(xstart, ystart+0.0005), xytext=(xstart+2.4, ystart+0.0005), color='black')
plt.show()

#### Figure 6 (c) and (d)

In [20]:
# Step through TFT's sub-modules on the full evaluation set to obtain the
# decoder attention weights (`decoder_weights`) plotted in the next cells.
tft.eval()
# Inference only: without no_grad an autograd graph is recorded for the whole
# evaluation set, which costs memory and time for no benefit.
with torch.no_grad():
    confe_output = tft.confe(eval_conti)
    catfe_output = tft.catfe(eval_cate)

    # Join both embeddings along the second-to-last axis.
    obs_feature = torch.cat([confe_output, catfe_output], dim=-2)
    x1, _ = tft.vsn1(obs_feature)
    future_embedding = tft.catfe(eval_future)
    x2, _ = tft.vsn2(future_embedding)
    delta, glu_phi, decoder_weights = tft.tfd(x1, x2)

KeyboardInterrupt: 

In [None]:
# TFT: median over the first axis of decoder-attention row 47 + k, for k = 1..6.
_tft_dec_w = decoder_weights.detach().numpy()
_median_by_k = {k: np.quantile(_tft_dec_w[:, 47 + k], 0.5, axis=0) for k in range(1, 7)}

g = sns.lineplot(_median_by_k[1], label=r"$k = 1$")
g.set_xticks([0, 12, 24, 36, 47, 59], ["-47", "-35", "-23", "-11", "0", "12"])
g.set_xlabel("Time points")
g.set_ylabel("Median of Attention Weights")
for k in range(2, 7):
    sns.lineplot(_median_by_k[k], label=rf"$k = {k}$")
plt.show()

In [None]:
# TFT: median over the first axis of decoder-attention row 47 + k, for k = 7..12.
_tft_dec_w = decoder_weights.detach().numpy()
_median_by_k = {k: np.quantile(_tft_dec_w[:, 47 + k], 0.5, axis=0) for k in range(7, 13)}

g = sns.lineplot(_median_by_k[7], label=r"$k = 7$")
g.set_xticks([0, 12, 24, 36, 47, 59], ["-47", "-35", "-23", "-11", "0", "12"])
g.set_xlabel("Time points")
g.set_ylabel("Median of Attention Weights")
for k in range(8, 13):
    sns.lineplot(_median_by_k[k], label=rf"$k = {k}$")
plt.show()

#### Figure E.10 (a) and (b)

In [None]:
# Forward pass of the "without M_S" variant on the full evaluation set.
# NOTE(review): the results are stored in the `*_no_sps` names, which overwrites
# the outputs of `instatran_wo_sps` used for Figure 4; re-run that cell before
# re-drawing Figure 4.
instatran_wo_M_S.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    (output_no_sps, ssa_weight1_no_sps, ssa_weight2_no_sps, tsa_weight_no_sps,
     dec_weights_no_sps, fi1_no_sps, fi2_no_sps) = instatran_wo_M_S(eval_conti, eval_cate, eval_future)

# Median over the first axis of decoder-attention row 47 + k, for k = 1..6.
_wo_ms_dec_w = dec_weights_no_sps.detach().numpy()
_median_by_k = {k: np.quantile(_wo_ms_dec_w[:, 47 + k], 0.5, axis=0) for k in range(1, 7)}

g = sns.lineplot(_median_by_k[1], label=r"$k = 1$")
g.set_xticks([0, 12, 24, 36, 47, 59], ["-47", "-35", "-23", "-11", "0", "12"])
g.set_xlabel("Time points")
g.set_ylabel("Attention weights")
for k in range(2, 7):
    sns.lineplot(_median_by_k[k], label=rf"$k = {k}$")
plt.show()

In [None]:
# Median over the first axis of decoder-attention row 47 + k, for k = 1..6,
# drawn on the same y-range as Figure 6.
# NOTE(review): this cell repeats k = 1..6 (rows 48-53) from the previous cell; the
# Figure 6 pair uses k = 7..12 (rows 54-59) for panel (b) — confirm which was intended.
_dec_w_no_sps = dec_weights_no_sps.detach().numpy()
_median_by_k = {k: np.quantile(_dec_w_no_sps[:, 47 + k], 0.5, axis=0) for k in range(1, 7)}

g = sns.lineplot(_median_by_k[1], label=r"$k = 1$")
g.set_ylim(0, 0.029)
g.set_xticks([0, 12, 24, 36, 47, 59], ["-47", "-35", "-23", "-11", "0", "12"])
g.set_xlabel("Time points")
g.set_ylabel("Median of attention weights")
for k in range(2, 7):
    sns.lineplot(_median_by_k[k], label=rf"$k = {k}$")
plt.show()

In [None]:
def plot_past_prediction_results2(true, preds, past, batch_num=0, dark=False):
    """Plot one entry's past observations, target and 10%-90% prediction band.

    Relies on the notebook-level imports (`torch`, `np`, `pd`, `sns`, `mpl`,
    `plt`, `mpl_style`).

    Parameters
    ----------
    true : torch.Tensor
        Targets; ``true[batch_num, :, 0]`` is drawn as the "Target" line.
    preds : torch.Tensor or numpy.ndarray
        Quantile predictions; ``preds[batch_num, :, 0, 0]`` and
        ``preds[batch_num, :, 0, 4]`` are used as the lower ("10%") and upper
        ("90%") bounds of the shaded band.
    past : torch.Tensor
        Past values; ``past[batch_num, :, 0]`` is drawn as the "Observed" line
        on the negative part of the x-axis.
    batch_num : int, default 0
        Index along the first axis selecting the entry to plot.
    dark : bool, default False
        Passed to ``mpl_style``; also switches the "Target" line to white.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.

    Notes
    -----
    Changes the global matplotlib rc settings (dpi, theme, font sizes).
    """
    mpl.rcParams["figure.dpi"] = 60
    mpl_style(dark=dark)
    SMALL_SIZE = 10
    MEDIUM_SIZE = 14
    BIGGER_SIZE = 18

    plt.rc('font', size=SMALL_SIZE)
    plt.rc('axes', titlesize=MEDIUM_SIZE)
    plt.rc('axes', labelsize=MEDIUM_SIZE)
    plt.rc('xtick', labelsize=MEDIUM_SIZE)
    plt.rc('ytick', labelsize=MEDIUM_SIZE)
    plt.rc('legend', fontsize=MEDIUM_SIZE)
    plt.rc('figure', titlesize=BIGGER_SIZE)

    # Only `preds` may arrive as a NumPy array; `true` and `past` are always tensors.
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    true = true.cpu().numpy()[batch_num, ...]
    past = past.cpu().numpy()

    site1_preds = preds[batch_num, :, 0, :]
    site1_past = past[batch_num, :, 0]
    horizon = true.shape[0]      # number of forecast steps (12 in this notebook)
    n_past = len(site1_past)     # number of past steps (48 in this notebook)

    df_site1 = pd.DataFrame({"10%": site1_preds[:, 0],
                             "90%": site1_preds[:, 4],
                             "Target": true[:, 0]}).reset_index().melt(id_vars=['index'])

    # Shift the past observations so that they end just before x = 0 ...
    df_past_site1 = pd.DataFrame({"Observed": site1_past}).reset_index().melt(id_vars=['index'])
    df_past_site1['index'] = df_past_site1['index'] - n_past
    # ... and append the first target value at x = 0 so both lines connect.
    df_past_site1 = pd.concat([df_past_site1, pd.DataFrame({"index": 0, "variable": "Observed", "value": true[0]})], axis=0)

    df_site1_past_pred = pd.concat([df_site1.loc[df_site1['variable'] == 'Target'], df_past_site1], axis=0)

    palette = {
        'Target': 'white' if dark else 'black',
        'Observed': 'tab:gray'
    }
    _, ax = plt.subplots()
    sns.lineplot(ax=ax, x='index', y='value', hue='variable', data=df_site1_past_pred, palette=palette)
    ax.fill_between(
        np.arange(horizon),
        df_site1.loc[df_site1['variable'] == '10%', 'value'],
        df_site1.loc[df_site1['variable'] == '90%', 'value'],
        color='blue',
        alpha=0.3,
        label=r'80% interval',
    )
    ax.set(xlim=(-n_past, horizon), xlabel='Time points', ylabel='Water Level/1000')
    ax.legend(loc = 'upper left')
    plt.show()

#### Figure 7 (a) - (d)

In [None]:
# Forecasts of all four models on the sample, for the Figure 7 panels.
# Everything runs under no_grad: this is inference only, so no autograd graph is needed.
instatran.eval()
deepar.eval()
tft.eval()
mqrnn.eval()

with torch.no_grad():
    (stt_output, ssa_weight1, ssa_weight2, tsa_weight,
     dec_weights, fi1, fi2) = instatran(eval_conti_sample, eval_cate_sample, eval_future_sample)

    # DeepAR returns a (mu, sigma) pair.
    output_deepar = deepar(eval_conti_sample, eval_cate_sample, eval_future_sample)
    output_deepar_mu, output_deepar_sigma = output_deepar

    tft_output = tft(eval_conti_sample, eval_cate_sample, eval_future_sample)
    mqrnn_output = mqrnn(eval_conti_sample, eval_cate_sample, eval_future_sample)

# Convert DeepAR's (mu, sigma) into quantile forecasts with the project helper.
deepar_output = gaussian_quantile(output_deepar_mu, output_deepar_sigma)

In [None]:
# InstaTran: past observations, target and 10%-90% band for sample entry 90 (one panel of Figure 7).
plot_past_prediction_results2(eval_label_sample, stt_output, eval_past_label_sample, batch_num=90)

In [None]:
# TFT: past observations, target and 10%-90% band for sample entry 90 (one panel of Figure 7).
plot_past_prediction_results2(eval_label_sample, tft_output, eval_past_label_sample, batch_num=90)

In [None]:
# MQRNN: past observations, target and 10%-90% band for sample entry 90 (one panel of Figure 7).
plot_past_prediction_results2(eval_label_sample, mqrnn_output, eval_past_label_sample, batch_num=90)

In [None]:
# DeepAR: past observations, target and 10%-90% band for sample entry 90 (one panel of Figure 7).
# `deepar_output` comes from `gaussian_quantile`, not directly from the network.
plot_past_prediction_results2(eval_label_sample, deepar_output, eval_past_label_sample, batch_num=90)

#### Figure D.9

In [None]:
# One heat map per (entry, row) pair: a single row of STA-LSTM's `alpha`,
# shown as a 1 x n strip with the y ticks hidden.
for batch_num, _row in ((60, 0), (1549, 1)):
    plt.matshow(alpha[batch_num, _row:_row + 1, ...].detach().numpy(), cmap='Reds')
    plt.colorbar()
    plt.xlabel("Variable index")
    plt.yticks([])
plt.show()

In [None]:
# Recompute the attention of the graph-attention layer in HSDSTM's first level by hand.
deng.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    gtcn_output = deng.levels[0][0](torch.tensor(test_input_for_deng))

    _gat = deng.levels[0][1].gat
    Wh = torch.matmul(gtcn_output, _gat.W)
    e = _gat._prepare_attentional_mechanism_input(Wh)
    # Positions without an edge in `deng.adj` get a very large negative score,
    # so they receive (numerically) zero weight after the softmax.
    zero_vec = torch.full_like(e, -9e15)
    attention = torch.where(deng.adj > 0, e, zero_vec)
    # torch.softmax is used instead of F.softmax: `F` is never imported in this notebook.
    alpha_ = torch.softmax(attention, dim=-1)

# First panel: "Dry" case (entry 60, time step 0); second: "Rainy" case (entry 1549, time step 1).
for batch_num, time_step in ((60, 0), (1549, 1)):
    plt.matshow(alpha_[batch_num, time_step, ...].detach().numpy(), cmap='Reds')
    plt.colorbar()
    plt.xlabel("Variable index")
    plt.ylabel("Variable index")
plt.show()

#### Table 2

In [None]:
# Forecasts of InstaTran and its three ablation variants on the full evaluation
# set; only the first return value (the quantile forecasts) is kept.
instatran.eval()
instatran_wo_M_S.eval()
instatran_parallel.eval()
instatran_w_tft_decoder.eval()

# Inference only: no_grad avoids recording four autograd graphs over the full set.
with torch.no_grad():
    instatran_output, _, _, _, _, _, _ = instatran(eval_conti, eval_cate, eval_future)
    wo_M_S_output, _, _, _, _, _, _ = instatran_wo_M_S(eval_conti, eval_cate, eval_future)
    parallel_output, _, _, _, _, _, _ = instatran_parallel(eval_conti, eval_cate, eval_future)
    tft_decoder_output, _, _, _, _, _, _ = instatran_w_tft_decoder(eval_conti, eval_cate, eval_future)


#### Table 2 - Quantile losses (QL) and q-Rates

In [None]:
# InstaTran QLs
# Mean quantile (pinball) loss max(q*(y - y_hat), (1-q)*(y_hat - y)) for
# q = 0.9, 0.7, 0.5, read from columns 4, 3, 2 of the last output axis.
_y = eval_label.squeeze()
_losses = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _y_hat = instatran_output[..., _col].squeeze()
    _pinball = torch.maximum(_q * (_y - _y_hat), (1 - _q) * (_y_hat - _y))
    _losses.append(_pinball.mean().detach().numpy().round(4))
print(*_losses)

In [None]:
# InstaTran q-Rates, |q - q-Rate|
# q-Rate = share of labels lying strictly below the predicted q-quantile.
_y = eval_label.squeeze().detach().numpy()
_report = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _rate = np.mean(_y < instatran_output[..., _col].squeeze().detach().numpy())
    _report += [_rate.round(3), np.abs(_q - _rate).round(3)]
print(*_report)

In [None]:
# InstaTran without M_S QLs
# Mean quantile (pinball) loss for q = 0.9, 0.7, 0.5 (output columns 4, 3, 2).
_y = eval_label.squeeze()
_losses = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _y_hat = wo_M_S_output[..., _col].squeeze()
    _pinball = torch.maximum(_q * (_y - _y_hat), (1 - _q) * (_y_hat - _y))
    _losses.append(_pinball.mean().detach().numpy().round(4))
print(*_losses)

In [None]:
# InstaTran without M_S q-Rates, |q - q-Rate|
# q-Rate = share of labels lying strictly below the predicted q-quantile.
_y = eval_label.squeeze().detach().numpy()
_report = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _rate = np.mean(_y < wo_M_S_output[..., _col].squeeze().detach().numpy())
    _report += [_rate.round(3), np.abs(_q - _rate).round(3)]
print(*_report)

In [None]:
# InstaTran parallel attention QLs
# Mean quantile (pinball) loss for q = 0.9, 0.7, 0.5 (output columns 4, 3, 2).
_y = eval_label.squeeze()
_losses = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _y_hat = parallel_output[..., _col].squeeze()
    _pinball = torch.maximum(_q * (_y - _y_hat), (1 - _q) * (_y_hat - _y))
    _losses.append(_pinball.mean().detach().numpy().round(4))
print(*_losses)

In [None]:
# InstaTran parallel attention q-Rates, |q - q-Rate|
# q-Rate = share of labels lying strictly below the predicted q-quantile.
_y = eval_label.squeeze().detach().numpy()
_report = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _rate = np.mean(_y < parallel_output[..., _col].squeeze().detach().numpy())
    _report += [_rate.round(3), np.abs(_q - _rate).round(3)]
print(*_report)

In [None]:
# InstaTran with TFT decoder QLs
# Mean quantile (pinball) loss for q = 0.9, 0.7, 0.5 (output columns 4, 3, 2).
_y = eval_label.squeeze()
_losses = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _y_hat = tft_decoder_output[..., _col].squeeze()
    _pinball = torch.maximum(_q * (_y - _y_hat), (1 - _q) * (_y_hat - _y))
    _losses.append(_pinball.mean().detach().numpy().round(4))
print(*_losses)

In [None]:
# InstaTran with TFT decoder q-Rates, |q - q-Rate|
# q-Rate = share of labels lying strictly below the predicted q-quantile.
_y = eval_label.squeeze().detach().numpy()
_report = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _rate = np.mean(_y < tft_decoder_output[..., _col].squeeze().detach().numpy())
    _report += [_rate.round(3), np.abs(_q - _rate).round(3)]
print(*_report)

### Table 3

In [None]:
# Variable-importance tensors for Table 3, computed on the full evaluation set:
# `fi1` from InstaTran and `tft_vsn_output` from TFT's first variable-selection network.
instatran.eval()
tft.eval()

# Inference only: no_grad avoids recording autograd graphs over the full set.
with torch.no_grad():
    (output, ssa_weight1, ssa_weight2, tsa_weight,
     dec_weights, fi1, fi2) = instatran(eval_conti, eval_cate, eval_future)

    confe_output = tft.confe(eval_conti)
    catfe_output = tft.catfe(eval_cate)
    # Join both embeddings along the second-to-last axis.
    obs_feature = torch.cat([confe_output, catfe_output], dim=-2)
    x1, tft_vsn_output = tft.vsn1(obs_feature)

#### Table 3 - InstaTran

In [None]:
# mean of InstaTran's importance, one value per variable (16 columns)
_fi1_flat = fi1.detach().numpy().squeeze().reshape(-1, 16)
np.round(_fi1_flat.mean(axis=0), 3)

In [None]:
# std of InstaTran's importance, one value per variable (16 columns)
_fi1_flat = fi1.detach().numpy().squeeze().reshape(-1, 16)
np.round(_fi1_flat.std(axis=0), 3)

In [None]:
# 0.1-quantile of InstaTran's importance, one value per variable (16 columns)
_fi1_flat = fi1.detach().numpy().squeeze().reshape(-1, 16)
np.round(np.quantile(_fi1_flat, 0.1, axis=0), 3)

In [None]:
# 0.5-quantile of InstaTran's importance, one value per variable (16 columns)
_fi1_flat = fi1.detach().numpy().squeeze().reshape(-1, 16)
np.round(np.quantile(_fi1_flat, 0.5, axis=0), 3)

In [None]:
# 0.9-quantile of InstaTran's importance, one value per variable (16 columns)
_fi1_flat = fi1.detach().numpy().squeeze().reshape(-1, 16)
np.round(np.quantile(_fi1_flat, 0.9, axis=0), 3)

#### Table 3 - TFT

In [None]:
# mean of TFT's variable-selection output, restricted to the first 16 variables
_tft_vsn_flat = tft_vsn_output.detach().numpy().squeeze()[..., :16].reshape(-1, 16)
np.round(_tft_vsn_flat.mean(axis=0), 3)

In [None]:
# std of TFT's variable-selection output, restricted to the first 16 variables
_tft_vsn_flat = tft_vsn_output.detach().numpy().squeeze()[..., :16].reshape(-1, 16)
np.round(_tft_vsn_flat.std(axis=0), 3)

In [None]:
# 0.1-quantile of TFT's variable-selection output, restricted to the first 16 variables
_tft_vsn_flat = tft_vsn_output.detach().numpy().squeeze()[..., :16].reshape(-1, 16)
np.round(np.quantile(_tft_vsn_flat, 0.1, axis=0), 3)

In [None]:
# 0.5-quantile of TFT's variable-selection output, restricted to the first 16 variables
_tft_vsn_flat = tft_vsn_output.detach().numpy().squeeze()[..., :16].reshape(-1, 16)
np.round(np.quantile(_tft_vsn_flat, 0.5, axis=0), 3)

In [None]:
# 0.9-quantile of TFT's variable-selection output, restricted to the first 16 variables
_tft_vsn_flat = tft_vsn_output.detach().numpy().squeeze()[..., :16].reshape(-1, 16)
np.round(np.quantile(_tft_vsn_flat, 0.9, axis=0), 3)

#### Table 5 - Deep learning based models

In [None]:
# Forecasts of the deep-learning baselines on the full evaluation set (Table 5).
deepar.eval()
mqrnn.eval()
tft.eval()

# Inference only: no_grad avoids recording autograd graphs over the full set.
with torch.no_grad():
    # DeepAR returns a (mu, sigma) pair.
    output_deepar = deepar(eval_conti, eval_cate, eval_future)
    output_deepar_mu, output_deepar_sigma = output_deepar

    mqrnn_output = mqrnn(eval_conti, eval_cate, eval_future)
    tft_output = tft(eval_conti, eval_cate, eval_future)

# Convert DeepAR's (mu, sigma) into quantile forecasts with the project helper.
deepar_output = gaussian_quantile(output_deepar_mu, output_deepar_sigma)

In [None]:
# InstaTran QLs
# Mean quantile (pinball) loss for q = 0.9, 0.7, 0.5 (output columns 4, 3, 2).
_y = eval_label.squeeze()
_losses = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _y_hat = instatran_output[..., _col].squeeze()
    _pinball = torch.maximum(_q * (_y - _y_hat), (1 - _q) * (_y_hat - _y))
    _losses.append(_pinball.mean().detach().numpy().round(4))
print(*_losses)

In [None]:
# InstaTran q-Rates, |q - q-Rate|
# q-Rate = share of labels lying strictly below the predicted q-quantile.
_y = eval_label.squeeze().detach().numpy()
_report = []
for _q, _col in ((0.9, 4), (0.7, 3), (0.5, 2)):
    _rate = np.mean(_y < instatran_output[..., _col].squeeze().detach().numpy())
    _report += [_rate.round(3), np.abs(_q - _rate).round(3)]
print(*_report)

In [None]:
# DeepAR QLs
# Mean pinball loss max(q * (y - y_hat), (1 - q) * (y_hat - y)) for each level q,
# where y_hat is read from the paired index on the forecast's last axis.
# deepar_output is wrapped in a torch.Tensor once and reused for all levels.
deepar_quantiles = torch.Tensor(deepar_output)
print(*[
    torch.maximum(
        q * (eval_label.squeeze() - deepar_quantiles[..., k].squeeze()),
        (1 - q) * (deepar_quantiles[..., k].squeeze() - eval_label.squeeze()),
    ).mean().detach().numpy().round(4)
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
])

In [None]:
# DeepAR q-Rates, |q - q-Rate|
# For each level q (paired with its index on the forecast's last axis) print the
# share of targets lying strictly below that forecast, then its absolute gap to q.
# deepar_output is wrapped in a torch.Tensor once and reused for all levels.
deepar_quantiles = torch.Tensor(deepar_output)
print(*[
    stat
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
    # single-element loop: binds the rate once so both printed values share it
    for rate in [np.mean(
        eval_label.squeeze().detach().numpy()
        < deepar_quantiles[..., k].squeeze().detach().numpy()
    )]
    for stat in (rate.round(3), np.abs(q - rate).round(3))
])

In [None]:
# MQRnn QLs
# Mean pinball loss max(q * (y - y_hat), (1 - q) * (y_hat - y)) for each level q,
# where y_hat is read from the paired index on the forecast's last axis.
print(*[
    torch.maximum(
        q * (eval_label.squeeze() - mqrnn_output[..., k].squeeze()),
        (1 - q) * (mqrnn_output[..., k].squeeze() - eval_label.squeeze()),
    ).mean().detach().numpy().round(4)
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
])

In [None]:
# MQRnn q-Rates, |q - q-Rate|
# For each level q (paired with its index on the forecast's last axis) print the
# share of targets lying strictly below that forecast, then its absolute gap to q.
print(*[
    stat
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
    # single-element loop: binds the rate once so both printed values share it
    for rate in [np.mean(
        eval_label.squeeze().detach().numpy()
        < mqrnn_output[..., k].squeeze().detach().numpy()
    )]
    for stat in (rate.round(3), np.abs(q - rate).round(3))
])

In [None]:
# TFT QLs
# Mean pinball loss max(q * (y - y_hat), (1 - q) * (y_hat - y)) for each level q,
# where y_hat is read from the paired index on the forecast's last axis.
print(*[
    torch.maximum(
        q * (eval_label.squeeze() - tft_output[..., k].squeeze()),
        (1 - q) * (tft_output[..., k].squeeze() - eval_label.squeeze()),
    ).mean().detach().numpy().round(4)
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
])

In [None]:
# TFT q-Rates, |q - q-Rate|
# For each level q (paired with its index on the forecast's last axis) print the
# share of targets lying strictly below that forecast, then its absolute gap to q.
print(*[
    stat
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
    # single-element loop: binds the rate once so both printed values share it
    for rate in [np.mean(
        eval_label.squeeze().detach().numpy()
        < tft_output[..., k].squeeze().detach().numpy()
    )]
    for stat in (rate.round(3), np.abs(q - rate).round(3))
])

In [None]:
# HSDSTM QLs
# Mean pinball loss max(q * (y - y_hat), (1 - q) * (y_hat - y)) for each level q,
# where y_hat is read from the paired index on the forecast's last axis.
print(*[
    torch.maximum(
        q * (eval_label.squeeze() - deng_output[..., k].squeeze()),
        (1 - q) * (deng_output[..., k].squeeze() - eval_label.squeeze()),
    ).mean().detach().numpy().round(4)
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
])

In [None]:
# HSDSTM q-Rates, |q - q-Rate|
# For each level q (paired with its index on the forecast's last axis) print the
# share of targets lying strictly below that forecast, then its absolute gap to q.
print(*[
    stat
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
    # single-element loop: binds the rate once so both printed values share it
    for rate in [np.mean(
        eval_label.squeeze().detach().numpy()
        < deng_output[..., k].squeeze().detach().numpy()
    )]
    for stat in (rate.round(3), np.abs(q - rate).round(3))
])

In [None]:
# STA-LSTM QLs
# Mean pinball loss max(q * (y - y_hat), (1 - q) * (y_hat - y)) for each level q,
# where y_hat is read from the paired index on the forecast's last axis.
# eval_label is already a torch.FloatTensor (built right after the data is
# loaded), so it is used directly: re-wrapping it in torch.tensor() only made
# copies and triggered PyTorch's copy-construct UserWarning.
print(*[
    torch.maximum(
        q * (eval_label.squeeze() - ding_output[..., k].squeeze()),
        (1 - q) * (ding_output[..., k].squeeze() - eval_label.squeeze()),
    ).mean().detach().numpy().round(4)
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
])

In [None]:
# STA-LSTM q-Rates, |q - q-Rate|
# For each level q (paired with its index on the forecast's last axis) print the
# share of targets lying strictly below that forecast, then its absolute gap to q.
print(*[
    stat
    for q, k in ((0.9, 4), (0.7, 3), (0.5, 2))
    # single-element loop: binds the rate once so both printed values share it
    for rate in [np.mean(
        eval_label.squeeze().detach().numpy()
        < ding_output[..., k].squeeze().detach().numpy()
    )]
    for stat in (rate.round(3), np.abs(q - rate).round(3))
])