In [37]:
%matplotlib inline
import os
import pandas as pd
import re
import pickle as pkl
from utils.metrics import *

# Path template for the forecasting outputs. The three %s slots are filled with
# (model, dataset, subfolder), where the subfolder is either 'raw' or the name
# of a compressor ('pmc', 'swing', 'sz').
root_baseline_data_path = '../output/%s/%s/%s/'
# Window length used by get_borders(): the start of the validation and test
# splits is moved back by this many rows.
seq_len = 96


def get_borders(data_name, data_len, lookback=None):
    """Return the start and end row offsets of the train/validation/test splits.

    Args:
        data_name: Dataset identifier. ``ETTm1``/``ETTm2`` (matched
            case-insensitively, because the rest of this notebook spells them
            ``ettm1``/``ettm2``) use fixed split sizes; every other dataset is
            split 70% / 10% / 20%.
        data_len: Total number of rows in the dataset. Ignored for ETTm1/ETTm2.
        lookback: Number of rows by which the validation and test start offsets
            are moved back. Defaults to the module-level ``seq_len``.

    Returns:
        ``(border1s, border2s)``: two 3-element lists holding the start and the
        (exclusive) end offset of the train, validation and test split.
    """
    if lookback is None:
        lookback = seq_len

    if str(data_name).lower() in ('ettm1', 'ettm2'):
        # Fixed split of 12 / 4 / 4 blocks of 30 * 24 * 4 rows each
        # (presumably months of 15-minute samples -- TODO confirm).
        block = 30 * 24 * 4
        train_end = 12 * block
        vali_end = train_end + 4 * block
        test_end = train_end + 8 * block
        border1s = [0, train_end - lookback, vali_end - lookback]
        border2s = [train_end, vali_end, test_end]
        return border1s, border2s

    num_train = int(data_len * 0.7)
    num_test = int(data_len * 0.2)
    # Validation takes whatever is left after the truncating int() casts.
    num_vali = data_len - num_train - num_test
    border1s = [0, num_train - lookback, data_len - num_test - lookback]
    border2s = [num_train, num_train + num_vali, data_len]
    return border1s, border2s


def get_eb(exp_str):
    """Extract the error bound encoded in an experiment or file name.

    The name is searched for tokens of the form ``eb`` followed by digits,
    dots and underscores (e.g. ``eb_0.05``, ``eb_5_``). Underscores are
    dropped, and a trailing dot (picked up when the token is directly followed
    by a file extension such as ``.pickle``) is ignored. Values ``>= 1`` are
    treated as percentages and scaled by 0.01.

    Args:
        exp_str: Name containing an ``eb...`` token.

    Returns:
        The error bound as a float fraction.

    Raises:
        ValueError: If no ``eb`` token with digits is present or the token
            cannot be parsed as a number.
    """
    token = None
    for candidate in re.findall('eb[_.0-9]+', exp_str):
        # Skip accidental matches without digits, e.g. the 'eb_' inside 'web_'.
        if any(ch.isdigit() for ch in candidate):
            token = candidate
            break
    if token is None:
        raise ValueError(f'no error-bound token found in {exp_str!r}')

    number = token.replace('eb', '').replace('_', '')
    # 'eb_0.05.pickle' yields '0.05.'; the trailing dot belongs to the extension.
    number = number.rstrip('.')
    eb = float(number)
    if eb >= 1:
        eb *= 0.01

    return eb


def load(path):
    """Unpickle and return the object stored in the file at ``path``.

    The file is opened in binary mode and closed again before returning.
    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(path, mode='rb') as stream:
        payload = pkl.load(stream)
    return payload


def metrics_ensemble(pred, true):
    """Score ``pred`` against ``true`` with every metric helper and collect the results.

    The helpers ``MAE``, ``RMSE``, ``RSE``, ``NRMSE``, ``CORR`` and ``PSNR`` are
    expected to be in scope (this notebook star-imports ``utils.metrics``).

    Returns:
        dict with the keys ``mae``, ``rmse``, ``nrmse``, ``rse``, ``corr`` and
        ``psnr``, in that order.
    """
    # Helpers are invoked in the same sequence as before; only the packing differs.
    mae_score, rmse_score = MAE(pred, true), RMSE(pred, true)
    rse_score = RSE(pred, true)
    nrmse_score = NRMSE(pred, true)
    corr_score = CORR(pred, true)
    psnr_score = PSNR(pred, true)

    report = dict(mae=mae_score, rmse=rmse_score, nrmse=nrmse_score)
    report.update(rse=rse_score, corr=corr_score, psnr=psnr_score)
    return report


def load_pkl(path):
    """Unpickle and return the object stored in the file at ``path``.

    Kept as an alias of ``load`` so both names share one implementation
    instead of two identical copies.
    """
    return load(path)


def get_baseline(model:str, data:str):
    """Median forecasting metrics of the uncompressed (error bound 0) runs.

    Loads ``testing_true.pickle`` from the ``raw`` output folder of
    ``model``/``data`` as ground truth, then walks the ``pmc``, ``swing`` and
    ``sz`` output folders and scores every file whose name contains
    ``eb_0_output`` against it.

    Returns:
        DataFrame of per-metric medians, indexed by ``eb`` (always 0.0).
    """
    truth = load(root_baseline_data_path % (model, data, 'raw') + 'testing_true.pickle')

    records = []
    for compressor in ('pmc', 'swing', 'sz'):
        top_dir = root_baseline_data_path % (model, data, compressor)
        for folder, _subdirs, names in os.walk(top_dir):
            for name in names:
                if 'eb_0_output' not in name:
                    continue
                prediction = load(f'{folder}/{name}')
                records.append(metrics_ensemble(prediction, truth))

    table = pd.DataFrame(records)
    # Every collected run belongs to the uncompressed setting.
    table['eb'] = [0.0] * len(records)
    return table.groupby(['eb']).median()


def get_transformation_error(data_file: str, data_name: str, eblc_name: str, target_ot: str, ebs_values: tuple):
    """Score the decompressed target column against the raw one on the test split.

    Reads ``../data/compressed/<eblc_name>/<data_file>`` and, for each error
    bound in ``ebs_values``, compares column ``<target_ot>-E<eb>`` with column
    ``<target_ot>-R`` over the test rows returned by ``get_borders``.

    Returns:
        DataFrame with one row of metrics per error bound plus an ``eb`` column.
    """
    frame = pd.read_parquet(f'../data/compressed/{eblc_name}/{data_file}')
    starts, ends = get_borders(data_name, len(frame))
    test_rows = slice(starts[2], ends[2])
    reference = frame[f'{target_ot}-R'].values[test_rows]

    bound_list = list(ebs_values)
    scores = [
        metrics_ensemble(frame[[f'{target_ot}-E{bound}']].values[test_rows][:, 0], reference)
        for bound in bound_list
    ]

    summary = pd.DataFrame(scores)
    summary['eb'] = bound_list
    return summary

def get_forecasting_results(model: str, data_file: str, data_name: str,
                            eblc_name: str, target_ot: str, ebs_values: tuple):
    """Collect forecasting metrics and decompression error for one compressor.

    Args:
        model: Forecasting model name (first slot of ``root_baseline_data_path``).
        data_file: Parquet file name below ``../data/compressed/<eblc_name>/``.
        data_name: Dataset name (second slot of ``root_baseline_data_path``,
            also passed to ``get_borders``).
        eblc_name: Compressor name (``pmc``, ``swing`` or ``sz``).
        target_ot: Base name of the target column in the parquet file.
        ebs_values: Error bounds whose ``<target_ot>-E<eb>`` columns are scored.

    Returns:
        ``(forecasting_results, dec_error)``:
        ``forecasting_results`` holds the per-error-bound median metrics of the
        prediction files, indexed by the bound parsed from each file name and
        sorted ascending; ``dec_error`` holds the per-error-bound median
        metrics of the decompressed target column versus the raw one.
    """
    # Decompression error on the test split; same computation as
    # get_transformation_error, so reuse it instead of keeping a second copy.
    dec_error = get_transformation_error(
        data_file, data_name, eblc_name, target_ot, ebs_values
    ).groupby(['eb']).median()

    raw_data_path = root_baseline_data_path % (model, data_name, 'raw') + 'testing_true.pickle'
    raw_data = load(raw_data_path)

    predictions_dir = root_baseline_data_path % (model, data_name, eblc_name) + 'predictions'
    tmetrics = []
    ebs = []
    for root, _dirs, files in os.walk(predictions_dir):
        for file in files:
            # 'eb_0_output' files are the uncompressed baseline (see get_baseline).
            if 'eb_0_output' in file:
                continue
            tmetrics.append(metrics_ensemble(load(os.path.join(root, file)), raw_data))
            ebs.append(get_eb(file))

    df = pd.DataFrame(tmetrics)
    df['eb'] = ebs
    # The original returned the undefined name `forecasting_results`, which is a
    # NameError on a fresh kernel and a stale global from the calling cell
    # otherwise. Return the frame computed here.
    forecasting_results = df.groupby(['eb']).median().sort_index()
    return forecasting_results, dec_error


def concat_baseline_forecasting_result(baseline_results, forecasting_results, dec_error):
    """Stack baseline and forecasting metrics into one table indexed by data error.

    The baseline rows come first, followed by the forecasting rows. The index
    (``error``) is 0.0 for the baseline and the ``nrmse`` values of
    ``dec_error`` for the rest. ``data_corr`` is 1.0 followed by the ``corr``
    values of ``dec_error``, and ``eb`` is 0.0 followed by the index of
    ``dec_error``. All columns are matched by position, so the stacked frame
    must have exactly ``1 + len(dec_error)`` rows.

    Returns:
        DataFrame indexed by ``error`` with the columns ``mae``, ``rmse``,
        ``nrmse``, ``rse``, ``corr``, ``data_corr`` and ``eb``.
    """
    stacked = pd.concat([baseline_results, forecasting_results], axis=0)

    columns = {'error': [0.0] + list(dec_error['nrmse'].values)}
    for metric in ('mae', 'rmse', 'nrmse', 'rse', 'corr'):
        columns[metric] = stacked[metric].values
    columns['data_corr'] = [1.0] + list(dec_error['corr'].values)
    columns['eb'] = [0.0] + list(dec_error.index)

    return pd.DataFrame(columns).set_index('error')



In [43]:
# Name of the target column of each dataset in the compressed parquet files.
target_variables_map = {'ettm1':'OT', 'ettm2': 'OT', 'aus_electrical_demand': 'y', 'weather': 'OT', 'wind': 'active power'}
# Error-bound grids: bounds[0] is passed for the 'sz' compressor, bounds[1]
# (the same grid multiplied by 100) for 'pmc' and 'swing'.
bounds = [(0.01, 0.03, 0.05, 0.07, 0.10, 0.15, 0.20, 0.25, 0.30, 0.40, 0.50, 0.65, 0.8), (1, 3, 5, 7, 10, 15, 20, 25, 30, 40, 50, 65, 80)]
# to_csv does not create missing directories.
os.makedirs('../results/tfe', exist_ok=True)
# for data in ['ettm1', 'ettm2', 'weather', 'aus_electrical_demand', 'solar']:
for m in ['gru', 'nbeats', 'transformer']:
    print('Baseline', m)
    display(get_baseline(m, 'aus'))
    per_dataset_results = []
    for data in ['aus_electrical_demand']:
        if 'solar' in data:
            # The earlier draft called get_forecasting_results(model='arima', ...)
            # without data_file / data_name / target_ot and unpacked three return
            # values from a 2-tuple, so this path could only fail with a TypeError.
            raise NotImplementedError("the 'solar' dataset is not wired up in this notebook")

        is_aus = 'aus' in data
        data_name = 'aus' if is_aus else data
        data_file = f'{data}_points.parquet' if is_aus else f'{data}_output_data_points.parquet'
        # The baseline does not depend on the compressor: compute it once per dataset.
        baseline_results = get_baseline(m, data_name)

        per_eblc_results = []
        for eblc in ['pmc', 'swing', 'sz']:
            print(m, eblc, data)
            if eblc == 'sz':
                ebs_values = bounds[0]
            elif is_aus:
                # Column names of this dataset use float bounds ('y-E1.0', ...).
                ebs_values = tuple(float(b) for b in bounds[1])
            else:
                ebs_values = bounds[1]

            if is_aus:
                target_ot = 'y'
            elif eblc == 'sz':
                target_ot = target_variables_map[data]
            else:
                target_ot = target_variables_map[data].replace(' ', '_')

            forecasting_results, dec_error = get_forecasting_results(model=m,
                                                                     data_file=data_file,
                                                                     data_name=data_name,
                                                                     eblc_name=eblc,
                                                                     target_ot=target_ot,
                                                                     ebs_values=ebs_values)
            concatenated = concat_baseline_forecasting_result(baseline_results, forecasting_results, dec_error)
            concatenated['eblc'] = eblc
            # The first row (error 0.0) is the uncompressed baseline.
            concatenated.at[0.0, 'eblc'] = 'baseline'
            per_eblc_results.append(concatenated)

        # The baseline row repeats for every compressor; keep a single copy.
        data_results = pd.concat(per_eblc_results).drop_duplicates()
        data_results['data'] = data
        per_dataset_results.append(data_results)

    all_results = pd.concat(per_dataset_results)
    all_results.to_csv(f'../results/tfe/{m}_results.csv')

Baseline gru


Unnamed: 0_level_0,mae,rmse,nrmse,rse,corr,psnr
eb,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0.0,0.607733,0.827525,0.134297,0.957168,0.581284,17.438654


gru pmc aus_electrical_demand
gru swing aus_electrical_demand
gru sz aus_electrical_demand
Baseline nbeats


Unnamed: 0_level_0,mae,rmse,nrmse,rse,corr,psnr
eb,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0.0,0.161433,0.2522,0.040929,0.29171,0.96059,27.759378


nbeats pmc aus_electrical_demand
nbeats swing aus_electrical_demand
nbeats sz aus_electrical_demand
Baseline transformer


Unnamed: 0_level_0,mae,rmse,nrmse,rse,corr,psnr
eb,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0.0,0.168923,0.258873,0.042012,0.299429,0.959474,27.532525


transformer pmc aus_electrical_demand
transformer swing aus_electrical_demand
transformer sz aus_electrical_demand
