In [52]:
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
import os
import pandas as pd

In [53]:
dataset_name = "M5"
methods = ["NBs_new", "DNNs_new"]

quantiles = np.array([0.5, 0.8, 0.9, 0.95, 0.99])

# Per-dataset settings as (max_lag, dlen) pairs.
#   max_lag: exclusive upper bound of the context-length sweep 1 .. max_lag-1
#            (used for every dataset except M5, which has its own grid below).
#   dlen:    expected first-axis length of the masked result arrays, checked by
#            the sanity assert after loading (presumably the number of evaluated
#            series — TODO confirm).
_dataset_specs = {
    'carparts':     (44, 2489),
    'OnlineRetail': (56, 2023),
    'Auto':         (16, 1227),
    'RAF':          (45, 5000),
    'M5':           (150, 29003),
}
max_lag = {name: lag for name, (lag, _) in _dataset_specs.items()}
dlen = {name: length for name, (_, length) in _dataset_specs.items()}

# M5 is evaluated on a hand-picked sparse grid of context lengths (a plain list);
# all other datasets sweep every context length 1 .. max_lag-1 (a numpy array).
if dataset_name == "M5":
    lags = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 50, 80, 100, 150]
else:
    lags = np.arange(1, max_lag[dataset_name])

In [54]:
def _load_pickle(path):
    """Load and return the object pickled at `path`, closing the file afterwards.

    NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    only use this on trusted result files produced by this project.
    """
    with open(path, 'rb') as fh:
        return pkl.load(fh)


# Lag-independent baselines: dict key in M -> file prefix on disk.
# The TweedieGP results are stored under the shorter "Tweedie" prefix.
baseline_file_prefix = {
    'iETS': 'iETS',
    'EmpQ': 'EmpQ',
    'TweedieGP': 'Tweedie',
}

# Every result file is named "<prefix>_<dataset_name>.pkl" in the working directory.
M = {}
for method in methods:
    M[method] = _load_pickle(method + '_' + dataset_name + ".pkl")
for key, prefix in baseline_file_prefix.items():
    M[key] = _load_pickle(prefix + '_' + dataset_name + ".pkl")

# OnlineRetail uses a precomputed boolean selection mask; every other dataset keeps all dlen entries.
mask = np.load("mask_onlineretail.npy") if dataset_name == "OnlineRetail" else np.ones(dlen[dataset_name], dtype=bool)

# Sanity check: after masking, every baseline's QL50 array must have exactly
# dlen[dataset_name] entries along its first axis.
assert (
    M['iETS']['QL50'][mask].shape[0]
    == M['EmpQ']['QL50'][mask].shape[0]
    == M['TweedieGP']['QL50'][mask].shape[0]
    == dlen[dataset_name]
), f"masked baseline result lengths do not match dlen[{dataset_name!r}]={dlen[dataset_name]}"

In [55]:
# Global matplotlib font sizes, applied to every figure created after this cell.
font_sizes = {
    'axes.titlesize': 16,   # subplot titles
    'axes.labelsize': 12,   # x / y axis labels
    'xtick.labelsize': 14,  # x tick labels
    'ytick.labelsize': 14,  # y tick labels
    'legend.fontsize': 12,  # legend entries
}
plt.rcParams.update(font_sizes)

In [None]:
# Index into `lags` of the first context length drawn in each panel.
min_lag = 0
Q = ['QL50', 'QL80', 'QL90', 'QL95', 'QL99']
# One panel per upper quantile; QL50 (Q[0]) is deliberately not plotted.
plotted_quantiles = Q[1:]

labels = {
    'NBs_new':'glm.nb',
    'DNNs_new':'fnn.nb',
}
colors = {
    'iETS':'#1E90FF',
    'EmpQ':'#4682B4',
    'TweedieGP':'magenta',
    'NBs_new':'red',
    'DNNs_new':'orange',
}
# Lag-independent baselines drawn as horizontal lines: (key in M and legend label, linestyle).
baseline_styles = [('EmpQ', ':'), ('iETS', '-'), ('TweedieGP', '-')]
# Width of the centred rolling mean applied to each per-lag curve.
# NOTE(review): the window counts positions in `lags`, not context-length units;
# on M5's non-uniform lag grid the smoothed neighbours are unevenly spaced.
smoothing_window = 3
# Optional per-dataset y-limits; with sharey=True they apply to every panel.
ylims = {
    'M5': (1, 3.5),
    'OnlineRetail': (2.1, 6),
    'RAF': (0.8, 6),
}

# Baseline lines span exactly the plotted context lengths. Using lags[min_lag]
# (not min_lag+1) keeps them aligned with the curves on M5's sparse grid too.
x_start, x_end = lags[min_lag], lags[-1]

fig, axs = plt.subplots(1, len(plotted_quantiles), figsize=(12,4), sharex=True, sharey=True)
for ax, q_ in zip(axs, plotted_quantiles):
    for key, style in baseline_styles:
        ax.hlines(y=np.mean(M[key][q_][mask]), xmin=x_start, xmax=x_end,
                  color=colors[key], linestyle=style, label=key, lw=2)
    for method in methods:
        # Mean masked score at every context length, then lightly smoothed.
        per_lag_mean = np.array([np.mean(M[method][l][q_][mask]) for l in lags])
        smoothed = (
            pd.Series(per_lag_mean)
            .rolling(window=smoothing_window, min_periods=1, center=True)
            .mean()
            .tolist()
        )
        ax.plot(lags[min_lag:], smoothed[min_lag:], color=colors[method], label=labels[method], lw=2)
    ax.set_title(r"$\mathbf{sQL_{"+q_[2:]+"}}$", fontweight='bold')
    ax.set_xlabel('context length')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

if dataset_name in ylims:
    axs[0].set_ylim(*ylims[dataset_name])

axs[0].legend(loc="upper left")
fig.tight_layout()
# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs("images", exist_ok=True)
fig.savefig(os.path.join("images", dataset_name+'.svg'), format='svg')