In [1]:
import numpy as np
import pandas as pd
import os
from dataloader import load_raw, create_datasets
from measures import quantile_loss_sample, compute_intermittent_indicators
from sklearn import tree
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
from tqdm import tqdm

# Root folder holding one sub-folder per trained-model run. Built from the
# current user's home directory (same convention as the baselines folder used
# further down) instead of a hard-coded absolute path of one machine.
path = os.path.join(os.path.expanduser("~"), "switchdrive", "iTS", "trained_models")

In [None]:
# Load the saved evaluation arrays of two runs stored under `path`.
# `actuals` is read from the tweedie run only and is compared against both
# forecast arrays below.
# NOTE(review): this presumes the negbin run was evaluated on the same actuals — confirm.
# The forecast arrays are indexed downstream as [series, sample, horizon].
actuals = np.load(os.path.join(path, "transformer__M5__tweedie__mean-demand__2024-07-01-08-37-15-418404/actuals.npy"))
forecasts_negbin = np.load(os.path.join(path, "transformer__M5__negbin__mean-demand__2024-06-30-03-26-19-805850/forecasts.npy"))
forecasts_tweedie = np.load(os.path.join(path, "transformer__M5__tweedie__mean-demand__2024-07-01-08-37-15-418404/forecasts.npy"))

# Cell output: the three array shapes, as a quick sanity check.
actuals.shape, forecasts_negbin.shape, forecasts_tweedie.shape

In [3]:
def _plot_quantile_fan(ax, history, actual, samples, color, title, back,
                       round_quantiles=False, levels=(0.5, 0.8, 0.9, 0.95, 0.99)):
    """Draw the observed series plus stacked forecast-quantile bands on `ax`.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    history : the trailing observations shown before the forecast window.
    actual : observed values over the forecast window.
    samples : 2-D array of forecast samples, indexed [sample, horizon].
    color : colour used for the median line and the bands.
    title : axes title.
    back : x position at which the forecast window starts.
    round_quantiles : if True, the plotted quantiles are rounded to integers.
    levels : quantile levels to shade (each band goes from 0 up to the quantile).
    """
    horizon = samples.shape[1]
    steps = np.arange(back, horizon + back)
    # Short segment joining the last observation to the first forecast step.
    bridge_x = np.linspace(back - 1, back)
    last_observed = history[-1]

    def quantile_path(level):
        values = np.quantile(samples, level, axis=0)
        return np.round(values) if round_quantiles else values

    ax.plot(np.append(history, actual), color="black", label="y")
    median = quantile_path(0.5)
    ax.plot(steps, median, color=color, label="yhat")
    ax.plot(bridge_x, np.linspace(last_observed, median[0]), color=color)

    for level in levels:
        band_label = "QL" + str(int(round(level * 100)))
        upper = quantile_path(level)
        ax.fill_between(steps, np.zeros(horizon), upper,
                        color=color, alpha=0.1, label=band_label, edgecolor="none")
        # The connector band is deliberately unlabeled: it only extends the
        # band above back to the last observation.
        ax.fill_between(bridge_x, np.zeros(bridge_x.size), np.linspace(last_observed, upper[0]),
                        color=color, alpha=0.1, edgecolor="none")
        # Tag the band just right of the last horizon, at the rounded quantile.
        ax.text(back + horizon - 0.8, np.round(np.quantile(samples[:, -1], level)), band_label, fontsize=8)
    ax.set_title(title)


def plot_forecasts(i, ts, back=14, r=3):
    """Plot negbin vs tweedie forecasts for series `i`.

    Reads the module-level arrays `actuals`, `forecasts_negbin` and
    `forecasts_tweedie`. Two figures are shown:

    1. one histogram panel per horizon with both sample distributions (rounded)
       and a marker at the actual value;
    2. two side-by-side panels (negbin, tweedie) with the last `back`
       observations of `ts`, the actuals, the median forecast and shaded
       quantile bands. Tweedie quantiles are rounded, negbin ones are not.

    Parameters
    ----------
    i : index of the series in the module-level arrays.
    ts : observed history preceding the forecast window.
    back : number of trailing points of `ts` to display.
    r : divisor applied to the number of distinct rounded sample values to
        obtain the histogram bin count (larger `r` -> coarser histogram).
    """
    n_horizons = forecasts_negbin.shape[2]

    # --- figure 1: per-horizon sample histograms -----------------------------
    _, axs = plt.subplots(2, (n_horizons + 1) // 2, figsize=(20,4), sharey=True)
    panels = axs.flatten()
    # Bin count is derived from the first horizon and shared by every panel.
    # It is clamped to 1 because int(k / r) is 0 whenever fewer than `r`
    # distinct values were sampled, and hist() rejects bins=0.
    distinct_values = max(np.unique(np.round(forecasts_negbin[i,:,0])).size,
                          np.unique(np.round(forecasts_tweedie[i,:,0])).size)
    n_bins = max(1, int(distinct_values / r))
    for h, ax in enumerate(panels[:n_horizons]):
        ax.hist(np.round(forecasts_negbin[i,:,h]), color="tab:orange", bins=n_bins, alpha=0.5)
        ax.hist(np.round(forecasts_tweedie[i,:,h]), color="m", bins=n_bins, alpha=0.5)
        ax.text(actuals[i,h], -.5, "▲")
        ax.set_xlabel("h="+str(h+1))
        for side in ('top', 'right', 'left'):
            ax.spines[side].set_visible(False)
    # Hide the spare panel when the horizon count is odd.
    for ax in panels[n_horizons:]:
        ax.set_visible(False)
    plt.suptitle("Forecast densities (i="+str(i)+")")
    plt.tight_layout()
    plt.show()

    # --- figure 2: history, actuals and quantile bands -----------------------
    history = ts[-back:]
    _, axs = plt.subplots(1,2, figsize=(20,4), sharey=True, sharex=True)
    _plot_quantile_fan(axs[0], history, actuals[i], forecasts_negbin[i],
                       color="tab:orange", title="negbin (i="+str(i)+")", back=back)
    # The x axis is shared, so ticks are configured once: blank over the
    # history, 1..H over the forecast window.
    axs[0].set_xticks(range(back + n_horizons))
    axs[0].set_xticklabels(back*[""] + [str(x) for x in range(1, n_horizons + 1)])
    _plot_quantile_fan(axs[1], history, actuals[i], forecasts_tweedie[i],
                       color="m", title="tweedie (i="+str(i)+")", back=back, round_quantiles=True)

    plt.tight_layout()
    plt.show()

In [None]:
# Empirical coverage of the forecast quantiles over every (series, horizon) pair.
# Nominal levels: 0, 0.025, ..., 0.975 plus 0.99.
quantiles = np.append(np.arange(0, 1, 0.025), 0.99)

# (label, forecast samples, round the quantiles?, colour). Tweedie quantiles are
# rounded before the comparison, negbin quantiles are used as they are.
coverage_models = (
    ("negbin", forecasts_negbin, False, "tab:orange"),
    ("tweedie", forecasts_tweedie, True, "m"),
)

predicted_quantiles = {}
for label, forecasts, round_levels, _ in coverage_models:
    coverage = []
    for samples, a in tqdm(zip(forecasts, actuals), total=forecasts.shape[0]):
        # One quantile call per series: result is (n_quantiles, horizon).
        levels = np.quantile(samples, quantiles, axis=0)
        if round_levels:
            levels = np.round(levels)
        # Rows = horizons, columns = nominal levels.
        coverage.append((a <= levels).T)
    coverage = np.concatenate(coverage)
    predicted_quantiles[label] = np.round(np.mean(coverage, axis=0), 3)
predicted_quantiles_negbin = predicted_quantiles["negbin"]
predicted_quantiles_tweedie = predicted_quantiles["tweedie"]

plt.figure(figsize=(8,5))
# Diagonal = perfectly calibrated forecasts.
plt.plot(quantiles, quantiles, color="black")
for label, _, _, color in coverage_models:
    observed = predicted_quantiles[label]
    plt.plot(quantiles, observed, label=label, color=color)
    plt.scatter(quantiles, observed, color=color, s=12)
    # Faint guide lines from the y axis to each point.
    plt.hlines(observed, xmin=quantiles[0], xmax=quantiles, colors=color, linestyles='dashed', alpha=0.5, lw=0.3)

plt.axhline(np.mean(actuals==0), ls='dashed', c="gray", label="proportion of zeros")
plt.legend(loc="lower right")
plt.title('Coverage')
plt.xlabel('quantiles')
plt.xticks(np.round(quantiles,3), fontsize=8, rotation=90)
plt.ylabel('coverage')
plt.show()

In [None]:
# Empirical coverage of the forecast quantiles restricted to (series, horizon)
# pairs whose actual value is non-zero.
# Nominal levels: 0, 0.025, ..., 0.975 plus 0.99.
quantiles = np.append(np.arange(0, 1, 0.025), 0.99)

# (label, forecast samples, round the quantiles?, colour). Tweedie quantiles are
# rounded before the comparison, negbin quantiles are used as they are.
coverage_models = (
    ("negbin", forecasts_negbin, False, "tab:orange"),
    ("tweedie", forecasts_tweedie, True, "m"),
)

predicted_quantiles = {}
for label, forecasts, round_levels, _ in coverage_models:
    coverage = []
    for samples, a in tqdm(zip(forecasts, actuals), total=forecasts.shape[0]):
        has_demand = a != 0
        if not has_demand.any():
            continue  # nothing to score for an all-zero forecast window
        # One quantile call per series, only on the horizons with demand:
        # result is (n_quantiles, n_demand_horizons).
        levels = np.quantile(samples[:, has_demand], quantiles, axis=0)
        if round_levels:
            levels = np.round(levels)
        # Rows = horizons with demand, columns = nominal levels.
        coverage.append((a[has_demand] <= levels).T)
    coverage = np.concatenate(coverage)
    predicted_quantiles[label] = np.round(np.mean(coverage, axis=0), 3)
predicted_quantiles_negbin = predicted_quantiles["negbin"]
predicted_quantiles_tweedie = predicted_quantiles["tweedie"]

plt.figure(figsize=(8,5))
# Diagonal = perfectly calibrated forecasts.
plt.plot(quantiles, quantiles, color="black")
for label, _, _, color in coverage_models:
    observed = predicted_quantiles[label]
    plt.plot(quantiles, observed, label=label, color=color)
    plt.scatter(quantiles, observed, color=color, s=12)
    # Faint guide lines from the y axis to each point.
    plt.hlines(observed, xmin=quantiles[0], xmax=quantiles, colors=color, linestyles='dashed', alpha=0.5, lw=0.3)

# Reference line: share of zeros among all actuals (not only the ones scored here).
plt.axhline(np.mean(actuals==0), ls='dashed', c="gray", label="proportion of zeros")
plt.legend(loc="lower right")
plt.title('Coverage of demand')
plt.xlabel('quantiles')
plt.xticks(np.round(quantiles,3), fontsize=8, rotation=90)
plt.ylabel('coverage')
plt.show()

In [6]:
# Raw M5 data, the derived dataset splits and the per-series ADI / CV^2
# intermittency indicators.
data_raw, data_info = load_raw(datasets_folder_path=os.path.join("..", "data"), dataset_name='M5')
datasets = create_datasets(data_raw, data_info)
adi, cv2 = compute_intermittent_indicators(data_raw, data_info['h'])
# Series flagged as intermittent: ADI >= 1.32 and CV^2 <= 0.49.
intermittent_mask = (adi >= 1.32) & (cv2 <= 0.49)

# Non-averaged sample quantile losses of both models against the same actuals.
ql_negbin, ql_tweedie = (
    quantile_loss_sample(actuals, model_forecasts, avg=False)
    for model_forecasts in (forecasts_negbin, forecasts_tweedie)
)

In [None]:
# Per-series gap in mean QL90 (negbin minus tweedie): the most negative values
# are series where negbin does better, the most positive where tweedie does.
tmp = (np.mean(ql_negbin['QL90'], axis=1) - np.mean(ql_tweedie['QL90'], axis=1))
# Ranks to inspect (200th to 209th) in each direction.
tmp_i = np.arange(200,210,1)

order = np.argsort(tmp)               # ascending gap, computed once
negbin_better = order[tmp_i]
tweedie_better = order[::-1][tmp_i]

# Each entry is (series index, is the series intermittent?). The flag is looked
# up with the same indices that are printed; the tweedie line used to reuse
# the negbin indices for the flag.
print("negbin is way better\t",   [(a,b) for a,b in zip(negbin_better, intermittent_mask[negbin_better])])
print("tweedie is way better\t",  [(a,b) for a,b in zip(tweedie_better, intermittent_mask[tweedie_better])])

In [None]:
# Inspect a single series: show its last 100 observations before the forecast
# window, with one histogram bin per distinct rounded sample value (r=1).
i = 11687
history = np.array(datasets['test'][i]['target'])[:-data_info['h']]  # drop the forecast window
plot_forecasts(i, ts=history, back=100, r=1)

In [None]:
from measures import quantile_loss

# Baseline runs: folders under `baseline_path` that contain a metrics.json,
# restricted to those whose second "__"-separated name field is "M5".
baseline_path = os.path.join(os.path.expanduser("~/switchdrive"), "iTS", "trained_models_baselines")
baselines_name = [folder for folder in os.listdir(baseline_path)
                  if os.path.isdir(os.path.join(baseline_path, folder)) and os.path.exists(os.path.join(baseline_path, folder, 'metrics.json'))]
# Folder names without a "__" separator are skipped rather than raising IndexError.
baselines_name_sub = [x for x in baselines_name if len(x.split('__')) > 1 and x.split('__')[1] == "M5"]

# Series subset used in the plot below.
# NOTE(review): `filter` shadows the Python builtin; the name is kept because
# the following cells use the same name.
subset="intermittent"
if subset == "intermittent":
    filter, filter_label = np.logical_and(adi >= 1.32, cv2 < .49), "intermittent"
elif subset == "intermittent_and_lumpy":
    filter, filter_label = adi >= 1.32, "intermittent_and_lumpy"
elif subset == "all":
    filter, filter_label = np.tile(True, adi.size), "all"
else:
    # Fail loudly instead of silently reusing a mask left over from another cell.
    raise ValueError("unknown subset: " + repr(subset))

# Scaling baseline: for each series, the quantile loss obtained on the
# validation target when forecasting its own (rounded) in-sample quantiles at
# every time step. The constant forecast is built one series at a time, so no
# (n_series, T, n_quantiles) array has to be allocated.
quantiles = [0.5,0.8,0.9,0.95, 0.99]
quantile_names = ['QL50','QL80','QL90','QL95','QL99']
res_base_scale_tmp = []
for i in range(len(datasets['test'])):
    valid_target = np.array(datasets['valid'][i]['target'])
    constant_forecast = np.tile(np.round(np.quantile(valid_target, q=quantiles)), (valid_target.size, 1))
    res_base_scale_tmp.append(quantile_loss(valid_target.reshape(1,-1), constant_forecast[np.newaxis], quantiles, avg=False))
# Per-series average baseline loss, restricted to the selected subset.
res_base_scale = {}
for q in quantile_names:
    res_base_scale[q] = np.mean(np.vstack([res_base_scale_tmp[i][q] for i in range(len(datasets['test']))]), axis=1)[filter]

aggf=np.mean   # aggregation across series
scale=True     # set to False to plot unscaled losses


def fscale(x, q):
    """Divide each series' loss (rows of `x`) by its baseline loss for `q`; identity when `scale` is False."""
    # NOTE(review): a series whose baseline loss is 0 yields inf/nan here — confirm this cannot happen.
    return x / res_base_scale[q][:, np.newaxis] if scale else x


fig, axs = plt.subplots(1, 5, figsize=(16,3))
plt.suptitle("Agg. scaled quantile loss")
for q, ax in zip(quantile_names, axs):
    ax.set_title(q)
    ax.set_xlabel('h')
    for model_label, model_ql, color in (("negbin", ql_negbin, "tab:orange"), ("tweedie", ql_tweedie, "m")):
        # Aggregate over series -> one value per horizon (computed once per curve).
        curve = aggf(fscale(model_ql[q][filter], q), axis=0)
        ax.plot(curve, color=color)
        # Name the curve just right of its last point.
        ax.text(curve.size - 0.5, curve[-1], model_label, color=color)
plt.tight_layout()
plt.show()

In [None]:
# Split the summed scaled quantile loss of each model between horizons whose
# actual is zero and horizons with demand, for the chosen subset of series.
subset = "intermittent"
# Lazily evaluated masks, one per supported subset name.
subset_masks = {
    "intermittent": lambda: np.logical_and(adi >= 1.32, cv2 < .49),
    "intermittent_and_lumpy": lambda: adi >= 1.32,
    "all": lambda: np.tile(True, adi.size),
}
if subset in subset_masks:
    filter, filter_label = subset_masks[subset](), subset

# Rebuild the per-series baseline losses WITHOUT subsetting, so that they line
# up with the full loss arrays; the subset mask is applied after scaling.
n_series = len(datasets['test'])
for q in ['QL50','QL80','QL90','QL95','QL99']:
    stacked = np.vstack([res_base_scale_tmp[k][q] for k in range(n_series)])
    res_base_scale[q] = np.mean(stacked, axis=1)


def fscale(x, q):
    """Divide each series' loss (rows of `x`) by its baseline loss for `q`."""
    return x / res_base_scale[q][:, np.newaxis]


actuals_subset = actuals[filter]
for q in ql_negbin.keys():
    print(q)
    scaled_negbin = fscale(ql_negbin[q], q)[filter]
    print('   negbin zero  \t', np.round(np.sum(scaled_negbin[actuals_subset == 0]), 2))
    print('   negbin demand\t', np.round(np.sum(scaled_negbin[actuals_subset != 0]), 2))
    scaled_tweedie = fscale(ql_tweedie[q], q)[filter]
    print('   tweedie zero  \t', np.round(np.sum(scaled_tweedie[actuals_subset == 0]), 2))
    print('   tweedie demand\t', np.round(np.sum(scaled_tweedie[actuals_subset != 0]), 2))

In [None]:
# Heat map (series x horizon) of which model has the lower quantile loss,
# one panel per quantile level, for the chosen subset of series.
subset = "intermittent_and_lumpy"
# Lazily evaluated masks, one per supported subset name.
subset_masks = {
    "intermittent": lambda: np.logical_and(adi >= 1.32, cv2 < .49),
    "intermittent_and_lumpy": lambda: adi >= 1.32,
    "all": lambda: np.tile(True, adi.size),
}
if subset in subset_masks:
    filter, filter_label = subset_masks[subset](), subset

fig, axs = plt.subplots(1,5, figsize=(16,4), sharey=True)
for q, ax in zip(ql_negbin.keys(), axs):
    loss_negbin, loss_tweedie = ql_negbin[q], ql_tweedie[q]
    # Cell codes: 2 = negbin loss is larger, 1 = equal losses, 0 = otherwise.
    box = np.where(loss_negbin > loss_tweedie, 2, np.where(loss_tweedie == loss_negbin, 1, 0))
    ax.imshow(box[filter], vmin=0, vmax=2, cmap="bwr", aspect='auto')
    ax.set_title(q)

plt.tight_layout()

## Analysis...

In [12]:
def _draw_tree_splits(clf, ax, xmin, xmax, ymin, ymax, lw=0.5):
    """Overlay the split thresholds of a fitted two-feature tree on `ax`.

    Splits on feature 0 are drawn as vertical lines spanning [ymin, ymax],
    splits on the other feature as horizontal lines spanning [xmin, xmax].
    Each line is drawn once and annotated with its threshold (2 decimals).
    """
    fitted = clf.tree_
    for feature, threshold in zip(fitted.feature, fitted.threshold):
        if feature == tree._tree.TREE_UNDEFINED:
            continue  # leaf node: no split to draw
        text = str(round(threshold, 2))
        if feature == 0:  # x axis
            ax.vlines(threshold, ymin, ymax, lw=lw, color="black")
            ax.text(threshold, ymax, text, fontsize=6)
        else:
            ax.hlines(threshold, xmin, xmax, lw=lw, color="black")
            ax.text(xmax, threshold, text, fontsize=6)


def main(DSET, MODEL="deepAR", s=3, alpha=0.1):
    """Explain, per quantile level, where negbin or tweedie forecasts win.

    For every run folder under the module-level `path` whose name contains both
    `DSET` and `MODEL`, the sample quantile losses are computed and averaged
    across runs of the same distribution ("__negbin__" / "__tweedie__" in the
    folder name). Each series is then labelled -1 (negbin has the lower mean
    loss), 0 (tie) or 1 (tweedie has the lower mean loss), and a shallow
    decision tree is fitted on (adi, cv2) to predict that label. For each
    quantile level the function prints the training accuracy and confusion
    matrix and shows the tree next to a log-log ADI/CV^2 scatter with the tree
    splits and the 1.32 / 0.49 reference thresholds.

    Parameters
    ----------
    DSET : dataset name; passed to `load_raw` and matched against folder names.
    MODEL : model name matched against folder names.
    s, alpha : marker size and opacity of the scatter plot.

    Raises
    ------
    FileNotFoundError
        If no run folder is found for one of the two distributions.
    """
    quantile_levels = [0.5, 0.8, 0.9, 0.95, 0.99]
    target_names = {-1: "negbin", 0: "tie", 1: "tweedie"}
    target_colors = {-1: "orange", 0: "green", 1: "purple"}

    run_dirs = [entry for entry in os.listdir(path)
                if os.path.isdir(os.path.join(path, entry)) and DSET in entry and MODEL in entry]

    data_raw, data_info = load_raw(dataset_name=DSET, datasets_folder_path=os.path.join("..", "data"))
    adi, cv2 = compute_intermittent_indicators(data_raw, data_info['h'])
    datasets = create_datasets(data_raw, data_info)

    horizon = data_info['h']
    context_width = horizon * data_info['w']

    # Quantile loss per distribution, averaged over all matching runs.
    actuals = None
    avg_ql = {}
    for distr in ['negbin', 'tweedie']:
        run_losses = []
        for d in run_dirs:
            if "__"+distr+"__" not in d:
                continue
            run_actuals = np.load(os.path.join(path, d, "actuals.npy"))
            forecasts = np.load(os.path.join(path, d, "forecasts.npy"))
            run_losses.append(quantile_loss_sample(run_actuals, forecasts, quantiles=quantile_levels, avg=False))
            if actuals is None:
                # Actuals of the first run found are used for every series row.
                actuals = run_actuals
        if not run_losses:
            raise FileNotFoundError(
                "no '{0}' run for dataset '{1}' and model '{2}' under {3}".format(distr, DSET, MODEL, path))
        avg_ql[distr] = {q: np.mean(np.stack([loss[q] for loss in run_losses]), axis=0)
                         for q in run_losses[0].keys()}
    q_names = list(avg_ql['negbin'])

    # One row per (series, quantile). Series-level features do not depend on
    # the quantile, so they are computed once per series.
    rows = []
    for i in range(len(datasets['test'])):
        cond_ts = np.array(datasets['test'][i]['target'][:-horizon])   # history before the forecast window
        condw = cond_ts[-context_width:]                               # conditioning window
        demand_w = condw[condw > 0]
        demand_ts = cond_ts[cond_ts > 0]
        series_features = (
            adi[i], cv2[i],
            np.mean(condw), np.median(condw),
            np.mean(condw == 0), np.median(demand_w) if demand_w.size else 1,
            np.mean(demand_ts) if demand_ts.size else np.nan,
        )
        for q in q_names:
            rows.append((actuals[i], q,
                         np.mean(avg_ql['negbin'][q][i]), np.mean(avg_ql['tweedie'][q][i]))
                        + series_features)
    comb = pd.DataFrame(rows, columns=['actual', 'q','ql_negbin', 'ql_tweedie', 'adi', 'cv2', 'condw_mean', 'condw_median', 'condw_0s', 'condw_meand', 'ts_meand'])

    def target_f(n, t):
        if n < t: return(-1)    # negbin wins
        if n == t: return(0)    # tie
        if n > t: return(1)     # tweedie wins

    comb['target'] = [target_f(n, t) for n, t in zip(comb['ql_negbin'], comb['ql_tweedie'])]

    # At least one sample per leaf, even on very small datasets.
    min_leaf = max(1, int(0.03*len(datasets['test'])))

    for q in comb.q.unique():
        rows_q = comb[comb.q == q]
        X = rows_q[['adi','cv2']]
        Y = rows_q['target']

        # random_state pins the tie-breaking between equally good splits.
        clf = tree.DecisionTreeClassifier(max_depth=4, min_samples_leaf=min_leaf, random_state=0)
        clf.fit(X, Y)
        y_pred = clf.predict(X)

        # In-sample scores: the tree is used as a description, not a predictor.
        print("Accuracy:", accuracy_score(Y, y_pred))
        print("Confusion matrix:\n", confusion_matrix(Y, y_pred))

        fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(12, max(1, clf.get_depth())*3))
        # class_names must follow the classes actually present (ascending);
        # a fixed three-name list mislabels nodes when e.g. no tie occurs.
        annotations = tree.plot_tree(clf, filled=True, feature_names=list(X.columns),
                                     class_names=[target_names[int(c)] for c in clf.classes_],
                                     rounded=True, ax=ax[0])
        # Node boxes, assumed to be listed in node-id order (as the original indexing did).
        node_boxes = [ann for ann in annotations if ann.get_bbox_patch() is not None]

        # Outline in red the leaves whose most frequent demand class is "intermittent".
        leaf_of_sample = clf.apply(X)
        for leaf in np.unique(leaf_of_sample):
            members = X.iloc[np.flatnonzero(leaf_of_sample == leaf)]
            c = {}
            c["intermittent"] = np.sum(np.logical_and(members.adi >= 1.32, members.cv2 <= 0.49))
            c["smooth"] = np.sum(np.logical_and(members.adi < 1.32, members.cv2 <= 0.49))
            c["erratic"] = np.sum(np.logical_and(members.adi < 1.32, members.cv2 > 0.49))
            c["lumpy"] = np.sum(np.logical_and(members.adi >= 1.32, members.cv2 > 0.49))
            sorted_c = dict(sorted(c.items(), key=lambda item: item[1], reverse=True))
            print(sorted_c)
            if list(sorted_c.keys())[0] == "intermittent":
                rect = node_boxes[leaf].get_bbox_patch()
                rect.set_edgecolor("red")
                rect.set_linewidth(2)

        plt.title(r"{0}, {1}, {2}".format(DSET, MODEL, q), fontsize=12)
        ax[1].scatter(X.values[:,0], X.values[:,1], c=[target_colors.get(x) for x in Y.values], s=s, alpha=alpha)
        xmin, xmax = np.min(X.values[:,0]), np.max(X.values[:,0])
        ymin, ymax = np.min(X.values[:,1]), np.max(X.values[:,1])
        _draw_tree_splits(clf, ax[1], xmin, xmax, ymin, ymax)
        # Reference thresholds: CV^2 = 0.49 and ADI = 1.32.
        ax[1].axhline(0.49, 0.045, 0.955, c="red", lw=1, ls='--')
        ax[1].axvline(1.32, 0, 0.95, c="red", lw=1, ls='--')
        ax[1].set_xscale('log')
        ax[1].set_yscale('log')
        ax[1].set_xlabel(X.columns[0])
        ax[1].set_ylabel(X.columns[1])
        ax[1].set_xticklabels([])
        ax[1].set_yticklabels([])
        plt.show()

In [None]:
# Negbin-vs-tweedie decision-tree analysis on run folders matching "M5" and "transformer".
main("M5", MODEL="transformer")

In [None]:
# Same analysis on run folders matching "OnlineRetail" and "deepAR", with larger, more opaque scatter markers.
main("OnlineRetail", MODEL="deepAR", s=6, alpha=0.5)

In [None]:
# Same analysis on run folders matching "carparts" and "deepAR", with larger, more opaque scatter markers.
main("carparts", MODEL="deepAR", s=6, alpha=0.5)

In [None]:
# Same analysis on run folders matching "RAF" and "deepAR", with larger, more opaque scatter markers.
main("RAF", MODEL="deepAR", s=6, alpha=0.5)

In [None]:
# Same analysis on run folders matching "Auto" and "deepAR", with larger, more opaque scatter markers.
main("Auto", MODEL="deepAR", s=6, alpha=0.5)