In [1]:
import sys
import copy
import pickle
import numpy as np
import pandas as pd

from tqdm import tqdm
from pathlib import Path
from IPython.display import display

# Make the repository root importable so `src.data` resolves. This assumes the
# notebook is run from a direct sub-directory of the repo (cwd's parent = repo root).
# The membership check keeps re-runs of this cell from appending duplicate entries.
_repo_root = str(Path.cwd().parent)
if _repo_root not in sys.path:
    sys.path.append(_repo_root)
from src.data import DataBundle

# Repository root; workspace folders and exported assets are resolved relative to it.
HOME = Path('..').absolute()

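Reproduce Table 1 with additional error metrics (RMSE, MAE, median absolute error, MAPE, sMAPE, NMAE and NRMSE), each reported as mean±std over the saved prediction runs of every method.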

In [2]:
def mean_absolute_error(y_true, y_pred):
    """Average absolute residual, in the target's units."""
    abs_residual = np.abs(y_true - y_pred)
    return np.mean(abs_residual)

def mean_squared_error(y_true, y_pred):
    """Average squared residual (units of the target squared)."""
    residual = y_true - y_pred
    return np.mean(residual ** 2)

def root_mean_squared_error(y_true, y_pred):
    """Square root of `mean_squared_error`, i.e. back in the target's units."""
    mse = mean_squared_error(y_true, y_pred)
    return np.sqrt(mse)

def mean_squared_logarithmic_error(y_true, y_pred):
    """Mean squared difference of log(1 + y); undefined for values <= -1."""
    log_gap = np.log1p(y_true) - np.log1p(y_pred)
    return np.mean(log_gap ** 2)

def root_mean_squared_logarithmic_error(y_true, y_pred):
    """Square root of `mean_squared_logarithmic_error`."""
    msle = mean_squared_logarithmic_error(y_true, y_pred)
    return np.sqrt(msle)

def r_squared(y_true, y_pred):
    """Coefficient of determination, 1 - SS_res / SS_tot.

    Divides by zero (inf/nan with a numpy warning) when y_true is constant.
    """
    residual = y_true - y_pred
    deviation = y_true - np.mean(y_true)
    return 1 - np.sum(residual ** 2) / np.sum(deviation ** 2)

def adjusted_r_squared(y_true, y_pred, n=None, p=1):
    """R-squared penalised for `p` predictors.

    Args:
        n: sample count; any falsy value (None, 0) falls back to len(y_true).
        p: number of predictors (default 1).
    """
    r2 = r_squared(y_true, y_pred)
    n_samples = n if n else len(y_true)
    return 1 - (1 - r2) * (n_samples - 1) / (n_samples - p - 1)

def mean_absolute_percentage_error(y_true, y_pred):
    """Mean |residual / y_true| in percent; inf/nan if y_true contains zeros."""
    relative_error = (y_true - y_pred) / y_true
    return 100 * np.mean(np.abs(relative_error))

def symmetric_mean_absolute_percentage_error(y_true, y_pred):
    """sMAPE in percent: |residual| relative to the mean magnitude of truth and prediction."""
    avg_magnitude = (np.abs(y_true) + np.abs(y_pred)) / 2
    return 100 * np.mean(np.abs(y_true - y_pred) / avg_magnitude)

def median_absolute_error(y_true, y_pred):
    """Median absolute residual; robust to a few large outliers."""
    abs_residual = np.abs(y_true - y_pred)
    return np.median(abs_residual)

def explained_variance_score(y_true, y_pred):
    """1 - Var(residual) / Var(y_true); unlike R-squared it ignores a constant bias."""
    residual_variance = np.var(y_true - y_pred)
    return 1 - residual_variance / np.var(y_true)

def normalized_mean_absolute_error(y_true, y_pred):
    """MAE as a percentage of the target range max(y_true) - min(y_true)."""
    target_range = np.max(y_true) - np.min(y_true)
    return mean_absolute_error(y_true, y_pred) / target_range * 100

def normalized_mean_squared_error(y_true, y_pred):
    """MSE normalized by the squared target range, in percent.

    MSE carries the target's units squared, so it is divided by
    (max(y_true) - min(y_true)) ** 2 to make it dimensionless. The * 100 matches
    the percentage convention of the NMAE / NRMSE helpers.
    Divides by zero when y_true is constant.
    """
    target_range = np.max(y_true) - np.min(y_true)
    return mean_squared_error(y_true, y_pred) / target_range ** 2 * 100

def normalized_root_mean_squared_error(y_true, y_pred):
    """RMSE as a percentage of the target range max(y_true) - min(y_true)."""
    target_range = np.max(y_true) - np.min(y_true)
    return root_mean_squared_error(y_true, y_pred) / target_range * 100

# Metrics reported in the result tables; insertion order sets the table row order.
# Defined above but currently disabled: mean_squared_error,
# mean_squared_logarithmic_error, root_mean_squared_logarithmic_error,
# r_squared, adjusted_r_squared, explained_variance_score.
metrics = dict([
    ("Root Mean Squared Error", root_mean_squared_error),
    ("Mean Absolute Error", mean_absolute_error),
    ("Median Absolute Error", median_absolute_error),
    ("Mean Absolute Percentage Error", mean_absolute_percentage_error),
    ("Symmetric Mean Absolute Percentage Error", symmetric_mean_absolute_percentage_error),
    ("Normalized Mean Absolute Error", normalized_mean_absolute_error),
    ("Normalized Root Mean Squared Error", normalized_root_mean_squared_error),
])

In [3]:
def format_scores(scores, precision=0):
    """Summarise each metric's per-run values as a 'mean±std' string.

    Args:
        scores: mapping of metric name -> sequence of per-run values.
        precision: decimal places for both mean and std (default 0, the
            format used in the tables).

    Returns:
        A new dict with the same keys; the input is not modified. std is the
        population standard deviation (np.std, ddof=0), so a single run gives '±0'.
    """
    formatted = {}
    for name, values in scores.items():
        mean, std = np.mean(values), np.std(values)
        formatted[name] = f'{mean:.{precision}f}±{std:.{precision}f}'
    return formatted


def _iter_predictions(folder):
    """Yield each unpickled `predictions*` file under HOME / folder, in sorted path order.

    Raises:
        FileNotFoundError: if no `predictions*` file exists, so a missing run
            fails loudly instead of appearing as 'nan±nan' in the table.
    """
    run_dir = HOME / folder
    paths = sorted(run_dir.glob('predictions*'))
    if not paths:
        raise FileNotFoundError(f'No predictions* files found in {run_dir}')
    for path in paths:
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # prediction files produced locally by this project.
        with open(path, 'rb') as f:
            yield pickle.load(f)


def _restore_targets(pred):
    """Return (y_true, y_pred) for one prediction record as flat, equally sized numpy arrays.

    When the bundle has a `label_transformation`, both arrays are passed through
    its `inverse_transform` before conversion.
    """
    data: DataBundle = pred['data']
    transform = data.label_transformation
    if transform is not None:
        y_true = transform.inverse_transform(data.test_data.label).cpu().numpy()
        y_pred = transform.inverse_transform(pred['prediction'].cpu()).numpy()
    else:
        y_true = data.test_data.label.cpu().numpy()
        y_pred = pred['prediction'].cpu().numpy()
    # Flatten both: an (N,) vs (N, 1) mismatch would otherwise broadcast to an
    # (N, N) residual inside the metrics and silently corrupt every score.
    # For equal shapes, flattening leaves all metric values unchanged.
    y_true, y_pred = np.ravel(y_true), np.ravel(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f'y_true has {y_true.size} values but y_pred has {y_pred.size}'
        )
    return y_true, y_pred


def calc_scores(folder: 'str | Path', metrics: dict) -> dict:
    """Score every saved prediction run in HOME / folder with each metric.

    Args:
        folder: workspace directory, relative to HOME, holding one or more pickled
            `predictions*` files. Each is a mapping with a 'data' entry
            (a DataBundle) and a 'prediction' entry.
        metrics: mapping of metric name -> func(y_true, y_pred).

    Returns:
        {metric name: 'mean±std'} across the runs (see `format_scores`).

    Raises:
        FileNotFoundError: if `folder` holds no `predictions*` file.
        ValueError: if a run's labels and predictions differ in size.
    """
    # Each pickled bundle is streamed from the generator and can be freed once
    # its targets are extracted.
    targets = [_restore_targets(pred) for pred in _iter_predictions(folder)]
    scores = {
        name: [func(y_true, y_pred) for y_true, y_pred in targets]
        for name, func in metrics.items()
    }
    return format_scores(scores)

In [4]:
# Dataset directory name -> display name used as the table title.
datasets = {
    'matr_1': 'MATR-1',
    'matr_2': 'MATR-2',
    'hust': 'HUST',
    'mix_100': 'MIX-100',
    'mix_20': 'MIX-20'
}
# Method directory name -> column header (LaTeX ``...'' quoting kept for export).
sklearn_baselines = {
    'dummy': 'Training Mean',
    'variance_model': '``Variance\'\' Model',
    'discharge_model': '``Discharge\'\' Model',
    'full_model': '``Full\'\' Model',
    'ridge': 'Ridge Regression',
    'pcr': 'PCR',
    'plsr': 'PLSR',
    'svm': 'SVM',
    'rf': 'Random Forest'
}
nn_baselines = {
    'cnn': 'CNN',
    'mlp': 'MLP',
    'lstm': 'LSTM'
}

# One entry per baseline family: (progress-bar label, workspace root, methods).
baseline_families = [
    ('sklearn baselines', Path('workspaces/baselines/sklearn'), sklearn_baselines),
    ('nn baselines', Path('workspaces/baselines/nn_models'), nn_baselines),
]
# BatLiNet results are read from the all-features ablation workspace.
batlinet_root = Path('workspaces/ablation/feature_spaces/all_features')

# scores[dataset display name][method display name] -> {metric name: 'mean±std'}.
# Column order follows insertion: sklearn baselines, nn baselines, then BatLiNet.
scores = {}
for dataset, dataset_name in datasets.items():
    per_method = scores[dataset_name] = {}
    for label, root, family in baseline_families:
        for method, method_name in tqdm(family.items(), desc=label):
            per_method[method_name] = calc_scores(root / method / dataset, metrics)
    per_method['BatLiNet'] = calc_scores(batlinet_root / dataset, metrics)

sklearn baselines: 100%|██████████| 9/9 [00:00<00:00, 23.77it/s]
nn baselines: 100%|██████████| 3/3 [00:01<00:00,  2.69it/s]
sklearn baselines: 100%|██████████| 9/9 [00:00<00:00, 180.68it/s]
nn baselines: 100%|██████████| 3/3 [00:00<00:00, 12.86it/s]
sklearn baselines: 100%|██████████| 9/9 [00:00<00:00, 73.85it/s]
nn baselines: 100%|██████████| 3/3 [00:00<00:00,  9.45it/s]
sklearn baselines: 100%|██████████| 9/9 [00:00<00:00, 403.89it/s]
nn baselines: 100%|██████████| 3/3 [00:02<00:00,  1.36it/s]
sklearn baselines: 100%|██████████| 9/9 [00:00<00:00, 370.86it/s]
nn baselines: 100%|██████████| 3/3 [00:00<00:00, 14.53it/s]


In [5]:
def extract_mean(value):
    """Parse the mean out of a 'mean±std' score string.

    Returns np.inf for anything else, such as the '-' placeholder or a non-string
    cell, so that such entries never win a minimum comparison.
    """
    try:
        return float(value.split('±')[0])
    except (AttributeError, ValueError):
        # AttributeError: value is not a string.
        # ValueError: the text before '±' is not a number.
        return np.inf

def highlight_min(s, css='background-color: black'):
    """Return per-cell CSS for one row, marking the lowest 'mean±std' entries.

    Intended for `DataFrame.style.apply(highlight_min, axis=1)`, where `s` is a
    row of score strings.
    - Every cell whose parsed mean equals the row minimum gets `css`, so ties
      after rounding are all highlighted.
    - Unparseable cells (e.g. '-') are never highlighted.
    - If no cell parses, nothing is highlighted.

    NOTE(review): `css` sets only the background colour. On a light theme with
    black text, the winner may be unreadable; pass e.g. 'font-weight: bold'.
    """
    means = s.apply(extract_mean)
    best = means.min()
    if not np.isfinite(best):
        return pd.Series('', index=s.index)
    return pd.Series(np.where(means == best, css, ''), index=s.index)

# Short row labels for the exported LaTeX tables ('MAD' = median absolute error).
# Keyed by metric name rather than position, so enabling or reordering entries in
# `metrics` cannot silently mislabel rows. Unlisted metrics keep their full name.
METRIC_ABBR = {
    'Root Mean Squared Error': 'RMSE',
    'Mean Absolute Error': 'MAE',
    'Median Absolute Error': 'MAD',
    'Mean Absolute Percentage Error': 'MAPE',
    'Symmetric Mean Absolute Percentage Error': 'sMAPE',
    'Normalized Mean Absolute Error': 'NMAE',
    'Normalized Root Mean Squared Error': 'NRMSE',
}
# Positional list in `metrics` order, kept for any code that used it.
metric_abbr = [METRIC_ABBR.get(name, name) for name in metrics]

# TODO: fill in NE results
# Hard-coded ``Discharge''/``Full'' model numbers for MATR-1/MATR-2 that replace
# the computed ones (presumably quoted from the original publication; TODO
# confirm source). Only RMSE and MAPE are given; every other metric of these
# models is shown as '-'.
REPORTED_SCORES = {
    'MATR-1': {
        '``Discharge\'\' Model': {'Root Mean Squared Error': '86±0', 'Mean Absolute Percentage Error': '10±0'},
        '``Full\'\' Model': {'Root Mean Squared Error': '100±0', 'Mean Absolute Percentage Error': '8±0'},
    },
    'MATR-2': {
        '``Discharge\'\' Model': {'Root Mean Squared Error': '173±0', 'Mean Absolute Percentage Error': '9±0'},
        '``Full\'\' Model': {'Root Mean Squared Error': '214±0', 'Mean Absolute Percentage Error': '11±0'},
    },
}

# Create the export directory up front so to_latex cannot fail on a fresh checkout.
LATEX_DIR = HOME / 'nmi_rebuttal_final' / 'assets'
LATEX_DIR.mkdir(parents=True, exist_ok=True)

for dataset, data_scores in scores.items():
    print(dataset)
    # Rows: metrics, columns: methods.
    score_table = pd.DataFrame(data_scores)
    for model_name, overrides in REPORTED_SCORES.get(dataset, {}).items():
        score_table[model_name] = '-'
        for metric_name, value in overrides.items():
            if metric_name in score_table.index:
                score_table.loc[metric_name, model_name] = value
    display(score_table.style.apply(highlight_min, axis=1))
    # Export transposed (methods as rows) with abbreviated metric names.
    score_table.rename(index=METRIC_ABBR).T.to_latex(LATEX_DIR / f'{dataset}.tex')

MATR-1


Unnamed: 0,Training Mean,``Variance'' Model,``Discharge'' Model,``Full'' Model,Ridge Regression,PCR,PLSR,SVM,Random Forest,CNN,MLP,LSTM,BatLiNet
Root Mean Squared Error,399±0,136±0,86±0,100±0,117±0,104±0,105±0,140±0,169±0,125±93,162±7,118±7,63±3
Mean Absolute Error,239±0,109±0,-,-,83±0,75±0,79±0,107±0,119±0,77±54,101±2,86±5,46±2
Median Absolute Error,139±0,89±0,-,-,56±0,48±0,62±0,86±0,75±0,48±31,70±4,60±8,33±4
Mean Absolute Percentage Error,28±0,15±0,10±0,8±0,11±0,11±0,11±0,15±0,17±0,9±6,12±0,13±1,6±0
Symmetric Mean Absolute Percentage Error,30±0,15±0,-,-,11±0,10±0,11±0,15±0,16±0,10±7,12±0,12±1,6±0
Normalized Mean Absolute Error,13±0,6±0,-,-,4±0,4±0,4±0,6±0,6±0,4±3,5±0,5±0,2±0
Normalized Root Mean Squared Error,21±0,7±0,-,-,6±0,5±0,6±0,7±0,9±0,7±5,9±0,6±0,3±0


MATR-2


Unnamed: 0,Training Mean,``Variance'' Model,``Discharge'' Model,``Full'' Model,Ridge Regression,PCR,PLSR,SVM,Random Forest,CNN,MLP,LSTM,BatLiNet
Root Mean Squared Error,511±0,211±0,173±0,214±0,186±0,243±0,181±0,300±0,240±0,237±108,207±4,233±43,162±10
Mean Absolute Error,414±0,136±0,-,-,118±0,164±0,117±0,204±0,164±0,176±95,128±3,157±26,116±8
Median Absolute Error,342±0,92±0,-,-,86±0,110±0,83±0,129±0,121±0,133±84,72±3,109±15,89±10
Mean Absolute Percentage Error,36±0,12±0,9±0,11±0,10±0,14±0,11±0,18±0,14±0,16±8,11±0,14±2,11±1
Symmetric Mean Absolute Percentage Error,46±0,13±0,-,-,10±0,16±0,10±0,19±0,16±0,17±12,11±0,15±2,11±1
Normalized Mean Absolute Error,30±0,10±0,-,-,8±0,12±0,8±0,15±0,12±0,13±7,9±0,11±2,8±1
Normalized Root Mean Squared Error,37±0,15±0,-,-,13±0,17±0,13±0,22±0,17±0,17±8,15±0,17±3,12±1


HUST


Unnamed: 0,Training Mean,``Variance'' Model,``Discharge'' Model,``Full'' Model,Ridge Regression,PCR,PLSR,SVM,Random Forest,CNN,MLP,LSTM,BatLiNet
Root Mean Squared Error,420±0,398±0,322±0,335±0,1047±0,435±0,431±0,344±0,345±0,482±138,444±3,441±25,268±28
Mean Absolute Error,341±0,319±0,264±0,270±0,649±0,364±0,349±0,291±0,291±0,390±56,336±18,339±21,179±25
Median Absolute Error,332±0,250±0,239±0,226±0,333±0,356±0,307±0,246±0,254±0,323±25,265±48,277±70,125±29
Mean Absolute Percentage Error,18±0,17±0,14±0,14±0,36±0,19±0,18±0,16±0,16±0,22±4,18±1,20±1,10±2
Symmetric Mean Absolute Percentage Error,18±0,17±0,13±0,13±0,37±0,19±0,18±0,16±0,15±0,20±2,18±1,18±1,10±1
Normalized Mean Absolute Error,25±0,23±0,19±0,20±0,47±0,26±0,25±0,21±0,21±0,28±4,24±1,25±2,13±2
Normalized Root Mean Squared Error,30±0,29±0,23±0,24±0,76±0,31±0,31±0,25±0,25±0,35±10,32±0,32±2,19±2


MIX-100


Unnamed: 0,Training Mean,``Variance'' Model,``Discharge'' Model,``Full'' Model,Ridge Regression,PCR,PLSR,SVM,Random Forest,CNN,MLP,LSTM,BatLiNet
Root Mean Squared Error,573±0,521±0,1737±0,331±0,395±0,384±0,371±0,257±0,214±0,252±21,461±30,265±15,168±15
Mean Absolute Error,416±0,345±0,378±0,210±0,280±0,264±0,245±0,159±0,136±0,150±10,244±13,159±11,113±18
Median Absolute Error,307±0,223±0,158±0,132±0,207±0,169±0,155±0,100±0,86±0,76±8,113±2,83±9,75±31
Mean Absolute Percentage Error,59±0,39±0,47±0,22±0,30±0,28±0,26±0,18±0,15±0,15±1,28±1,16±1,14±5
Symmetric Mean Absolute Percentage Error,44±0,36±0,26±0,22±0,29±0,27±0,25±0,16±0,14±0,14±1,22±1,15±1,12±3
Normalized Mean Absolute Error,18±0,15±0,16±0,9±0,12±0,11±0,10±0,7±0,6±0,6±0,10±1,7±0,5±1
Normalized Root Mean Squared Error,24±0,22±0,74±0,14±0,17±0,16±0,16±0,11±0,9±0,11±1,20±1,11±1,7±1


MIX-20


Unnamed: 0,Training Mean,``Variance'' Model,``Discharge'' Model,``Full'' Model,Ridge Regression,PCR,PLSR,SVM,Random Forest,CNN,MLP,LSTM,BatLiNet
Root Mean Squared Error,594±0,600±0,988653±0,437±0,837±0,707±0,482±0,461±0,290±0,6153±10687,519±25,376±61,207±17
Mean Absolute Error,457±0,447±0,81950±0,307±0,546±0,538±0,371±0,318±0,186±0,814±878,326±18,224±29,128±8
Median Absolute Error,388±0,339±0,299±0,190±0,391±0,427±0,272±0,183±0,107±0,125±8,171±19,104±10,68±7
Mean Absolute Percentage Error,102±0,96±0,5951±0,54±0,150±0,60±0,75±0,51±0,31±0,77±69,53±2,33±5,18±3
Symmetric Mean Absolute Percentage Error,66±0,65±0,61±0,46±0,70±0,48±0,57±0,46±0,26±0,33±1,44±2,29±3,16±1
Normalized Mean Absolute Error,22±0,21±0,3914±0,15±0,26±0,15±0,18±0,15±0,9±0,39±42,16±1,11±1,6±0
Normalized Root Mean Squared Error,28±0,29±0,47214±0,21±0,40±0,20±0,23±0,22±0,14±0,294±510,25±1,18±3,10±1
