In [1]:
import os
from glob import glob

import pandas as pd


def get_metadata(path):
    """Parse the run metadata encoded in a result file's name.

    The file name, with its directory and extension removed, must consist of
    nine ``--``-separated fields in this order: target, config, model,
    weights, freeze, batch_size, lr, epochs, fold.

    Parameters
    ----------
    path : str
        Path to a result file, for example
        ``dl_output/CS--Conf1--vgg16--DEFAULT--False--8--0.001--50--0.pkl``.
        The name must carry a file extension: the text after the last dot is
        always treated as the extension.

    Returns
    -------
    dict
        The nine fields as strings (no type conversion is performed), plus
        ``shortcode``: the first eight fields re-joined with ``--``, i.e. the
        file name without its fold.

    Raises
    ------
    ValueError
        If the file name does not split into exactly nine fields.
    """
    field_names = (
        "target",
        "config",
        "model",
        "weights",
        "freeze",
        "batch_size",
        "lr",
        "epochs",
        "fold",
    )

    # os.path copes with the platform's path separator and removes the actual
    # extension, instead of assuming "/" and a four-character suffix.
    file_name = os.path.splitext(os.path.basename(path))[0]
    fields = file_name.split("--")
    if len(fields) != len(field_names):
        raise ValueError(
            f"Expected {len(field_names)} '--'-separated fields "
            f"({', '.join(field_names)}) in {file_name!r}, got {len(fields)}"
        )

    # The shortcode identifies a run configuration independently of the fold.
    return dict(zip(field_names, fields), shortcode="--".join(fields[:-1]))


from sklearn.metrics import (
    mean_squared_error,
    r2_score,
    mean_absolute_percentage_error
)

def root_mean_squared_error(y_true, y_pred):
    """Return the root mean squared error between ``y_true`` and ``y_pred``.

    The square root is taken here instead of passing ``squared=False`` to
    ``sklearn.metrics.mean_squared_error``: that keyword was deprecated in
    scikit-learn 1.4 and removed in 1.6, so relying on it fails on current
    releases.

    Parameters
    ----------
    y_true : array-like
        Ground-truth target values.
    y_pred : array-like
        Predicted target values, same shape as ``y_true``.

    Returns
    -------
    float
        The RMSE. For multi-output targets the RMSE is computed per output
        and then averaged uniformly, which is what ``squared=False`` did.
    """
    # "raw_values" yields one MSE per output (a length-1 array for 1-D input).
    per_output_mse = mean_squared_error(y_true, y_pred, multioutput="raw_values")
    return (per_output_mse ** 0.5).mean()


def normalized_root_mean_squared_error(y_true, y_pred, norm_factor=None):
    """Return the RMSE expressed as a percentage of ``norm_factor``.

    Parameters
    ----------
    y_true : array-like
        Ground-truth target values.
    y_pred : array-like
        Predicted target values.
    norm_factor : number
        Value the RMSE is divided by, for example the average target value
        of the training set. It has a ``None`` default only so that omitting
        it produces an explicit error.

    Returns
    -------
    float
        ``rmse / norm_factor * 100``.

    Raises
    ------
    ValueError
        If ``norm_factor`` is not provided.
    """
    if norm_factor is None:
        # A real exception rather than `assert`: assertions are stripped under
        # `python -O`, which would turn this into an obscure TypeError below.
        raise ValueError(
            "Set norm_factor (for example the average target value for the training set)"
        )
    rmse = root_mean_squared_error(y_true, y_pred)
    return (rmse / norm_factor) * 100

In [2]:
# Run signature analysed by this cell. It is defined once so that the glob
# pattern, the shortcode filter and the printed header cannot drift apart.
MODEL = "vgg16"
RUN_SUFFIX = "DEFAULT--False--8--0.001--50"

# Columns of the loaded result frames that are summarised as "mean +/- std".
SUMMARY_COLUMNS = [
    "test_R2",
    "train_R2",
    "test_RMSE",
    "train_RMSE",
    "test_%RMSE",
    "train_%RMSE",
]

# Column layout of the per-configuration metrics tables written to disk.
EXPORT_COLUMNS = [
    "target",
    "config",
    "model_name",
    "model",
    "hyperparams",
    "fold",
    "model_obj",
    "MSE",
    "R2",
    "MAPE",
    "RMSE",
    "NRMSE",
    "MSE_train",
    "R2_train",
    "MAPE_train",
    "RMSE_train",
    "NRMSE_train",
]

# Metrics recomputed from the stored predictions: output column -> scorer.
METRIC_FUNCTIONS = {
    "MSE": mean_squared_error,
    "R2": r2_score,
    "MAPE": mean_absolute_percentage_error,
    "RMSE": root_mean_squared_error,
}

for TARGET in ["CS", "CSE"]:

    # ---- Load every result file of this target and tag it with its metadata.
    paths = glob(f"dl_output/{TARGET}--*--{MODEL}--{RUN_SUFFIX}*.pkl")
    if not paths:
        # Fail with an actionable message instead of pandas'
        # "No objects to concatenate".
        raise FileNotFoundError(
            f"No result files found for target {TARGET!r} "
            f"(pattern: dl_output/{TARGET}--*--{MODEL}--{RUN_SUFFIX}*.pkl)"
        )

    dfs = []
    for p in paths:
        # NOTE: unpickling executes arbitrary code; only load trusted files.
        df = pd.read_pickle(p)

        for k, v in get_metadata(p).items():
            df[k] = v

        # Conf5 runs are excluded from this analysis.
        df = df[~df.config.str.contains("Conf5")]

        dfs.append(df)

    # ignore_index gives every row a unique label; otherwise labels repeat
    # across files and a label-based lookup may return several rows.
    df = pd.concat(dfs, ignore_index=True)

    # ---- Keep only runs whose model and hyper-parameters match exactly (the
    # glob's trailing "*" would also match e.g. an epochs value of "500").
    tmp_idx = df.shortcode.str.split("--").str[2:].str.join("--") == f"{MODEL}--{RUN_SUFFIX}"

    # ---- For each (config, fold) keep the row with the highest test R2.
    # NOTE(review): this selects on the test score - presumably one row per
    # epoch/checkpoint; confirm that selecting on the test set is intended.
    best = []
    for (config, fold), g in df[tmp_idx].groupby(["config", "fold"]):
        idxmax = g.test_R2.astype(float).idxmax()
        best.append(g.loc[idxmax])
    best = pd.concat(best, axis=1, ignore_index=True).T

    # ---- Mean / std over folds of the stored metrics, per run shortcode.
    by_shortcode = best.groupby("shortcode")
    best_mean = by_shortcode.agg({col: "mean" for col in SUMMARY_COLUMNS})
    best_std = by_shortcode.agg({col: "std" for col in SUMMARY_COLUMNS})
    for col in best_mean.columns:
        # Round to a fixed precision; slicing the string representation
        # would corrupt large values and scientific notation.
        best_mean[col] = (
            best_mean[col].astype(float).map("{:.3f}".format)
            + " +/- "
            + best_std[col].astype(float).map("{:.3f}".format)
        )

    print("="*40)
    print(MODEL.upper(), "|", RUN_SUFFIX)
    # Show only the config part of the shortcode as the row label.
    best_mean.index = best_mean.index.str.split("--").str[1]
    display(best_mean)

    # ---- Recompute the metrics from the stored predictions of the selected
    # rows. NRMSE is normalised by the mean of the training targets for both
    # the test and the train split.
    data = best
    for split, suffix in (("test", ""), ("train", "_train")):
        real_col, pred_col = f"{split}_real", f"{split}_pred"
        for metric_name, metric_fn in METRIC_FUNCTIONS.items():
            data[metric_name + suffix] = data.apply(
                lambda row: metric_fn(row[real_col], row[pred_col]), axis=1
            )
        data["NRMSE" + suffix] = data.apply(
            lambda row: normalized_root_mean_squared_error(
                row[real_col], row[pred_col], norm_factor=row.train_real.mean()
            ),
            axis=1,
        )

    # ---- Write one metrics table per (target, config).
    for _, d in data.groupby(["target", "config"]):
        # assign() builds a new frame, so the group slice is never mutated
        # (avoids SettingWithCopyWarning).
        d = d.assign(
            hyperparams=None,
            model_obj=None,
            model="DeepCNN",
            model_name=None,
        )[EXPORT_COLUMNS]
        display(d)
        save_path = f"../results/metrics--{d.target.iloc[0]}--{d.config.iloc[0]}--DeepCNN.pickle"
        # Make sure the output directory exists before writing.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        d.to_pickle(save_path)
        print("metrics, predictions and models saved to\n", save_path)

VGG16 | DEFAULT--False--8--0.001--50


Unnamed: 0_level_0,test_R2,train_R2,test_RMSE,train_RMSE,test_%RMSE,train_%RMSE
shortcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Conf1,0.234 +/- 0.108,0.377 +/- 0.295,53.03 +/- 12.64,47.72 +/- 14.51,69.23 +/- 18.18,61.59 +/- 17.36
Conf2,0.272 +/- 0.154,0.419 +/- 0.176,51.96 +/- 14.93,46.36 +/- 8.784,68.31 +/- 20.90,60.52 +/- 10.33
Conf3,0.221 +/- 0.120,0.552 +/- 0.363,53.58 +/- 13.78,38.34 +/- 19.04,70.45 +/- 20.02,50.15 +/- 24.93
Conf4,0.298 +/- 0.106,0.323 +/- 0.105,50.86 +/- 12.84,50.47 +/- 4.217,66.44 +/- 18.48,65.63 +/- 6.238


Unnamed: 0,target,config,model_name,model,hyperparams,fold,model_obj,MSE,R2,MAPE,RMSE,NRMSE,MSE_train,R2_train,MAPE_train,RMSE_train,NRMSE_train
0,CS,Conf1,,DeepCNN,,0,,3358.714308,0.086546,2.466108,57.954416,76.777865,2105.556845,0.439177,2.740095,45.886347,60.790118
1,CS,Conf1,,DeepCNN,,1,,1294.154685,0.379016,1.124305,35.974361,44.968109,4314.514541,-0.011524,5.263354,65.684964,82.106492
2,CS,Conf1,,DeepCNN,,2,,3527.633623,0.283886,2.865045,59.393885,75.376284,1839.682973,0.466649,2.05572,42.891526,54.433277
3,CS,Conf1,,DeepCNN,,3,,4554.186713,0.192858,1.455879,67.484715,91.600632,751.941768,0.778728,1.556219,27.421557,37.220753
4,CS,Conf1,,DeepCNN,,4,,1968.795654,0.22874,8.201513,44.371113,57.431817,3217.366491,0.216367,2.147373,56.721834,73.417992


metrics, predictions and models saved to
 ../results/metrics--CS--Conf1--DeepCNN.pickle


Unnamed: 0,target,config,model_name,model,hyperparams,fold,model_obj,MSE,R2,MAPE,RMSE,NRMSE,MSE_train,R2_train,MAPE_train,RMSE_train,NRMSE_train
5,CS,Conf2,,DeepCNN,,0,,3137.291444,0.146766,3.756547,56.01153,73.702911,2388.849562,0.369215,2.783371,48.875859,64.313421
6,CS,Conf2,,DeepCNN,,1,,1082.652998,0.480502,1.118706,32.903693,41.751702,3466.555318,0.159635,3.80071,58.87746,74.709979
7,CS,Conf2,,DeepCNN,,2,,3303.194154,0.329448,3.32987,57.473421,73.062935,1840.684551,0.467713,2.478498,42.9032,54.540579
8,CS,Conf2,,DeepCNN,,3,,5109.814342,0.094384,1.636035,71.482965,97.223073,1210.907141,0.643429,2.347379,34.798091,47.328441
9,CS,Conf2,,DeepCNN,,4,,1758.211979,0.311234,5.786542,41.931038,55.849458,2148.656005,0.456665,1.602731,46.353598,61.740024


metrics, predictions and models saved to
 ../results/metrics--CS--Conf2--DeepCNN.pickle


Unnamed: 0,target,config,model_name,model,hyperparams,fold,model_obj,MSE,R2,MAPE,RMSE,NRMSE,MSE_train,R2_train,MAPE_train,RMSE_train,NRMSE_train
10,CS,Conf3,,DeepCNN,,0,,3244.913222,0.117496,2.086457,56.96414,74.520616,2106.059001,0.451942,2.109828,45.891818,60.03578
11,CS,Conf3,,DeepCNN,,1,,1329.55997,0.362027,1.070421,36.463132,46.104427,1904.610269,0.54918,2.035382,43.641841,55.181275
12,CS,Conf3,,DeepCNN,,2,,3422.030434,0.305324,1.970285,58.498123,74.300681,215.745485,0.937547,0.554234,14.688277,18.656137
13,CS,Conf3,,DeepCNN,,3,,5184.66582,0.081118,1.504977,72.004624,99.228894,595.09221,0.818378,0.793871,24.394512,33.617847
14,CS,Conf3,,DeepCNN,,4,,1937.421554,0.241031,5.549293,44.016151,58.123942,3979.873678,0.004582,2.842232,63.08624,83.30626


metrics, predictions and models saved to
 ../results/metrics--CS--Conf3--DeepCNN.pickle


Unnamed: 0,target,config,model_name,model,hyperparams,fold,model_obj,MSE,R2,MAPE,RMSE,NRMSE,MSE_train,R2_train,MAPE_train,RMSE_train,NRMSE_train
15,CS,Conf4,,DeepCNN,,0,,2858.625404,0.222553,2.7946,53.466115,69.920439,2139.319471,0.442138,2.29691,46.252778,60.487181
16,CS,Conf4,,DeepCNN,,1,,1082.151551,0.480743,0.916002,32.896072,41.416202,2918.420745,0.305864,3.055389,54.02241,68.014291
17,CS,Conf4,,DeepCNN,,2,,3422.072963,0.305315,3.437152,58.498487,74.201135,2073.345232,0.398162,2.193743,45.534001,57.756614
18,CS,Conf4,,DeepCNN,,3,,4319.937202,0.234374,1.494221,65.726229,90.279389,2787.956113,0.168473,4.212965,52.8011,72.525856
19,CS,Conf4,,DeepCNN,,4,,1913.069311,0.25057,5.707922,43.738648,56.42332,2891.90205,0.3049,2.235262,53.776408,69.372137


metrics, predictions and models saved to
 ../results/metrics--CS--Conf4--DeepCNN.pickle
VGG16 | DEFAULT--False--8--0.001--50


Unnamed: 0_level_0,test_R2,train_R2,test_RMSE,train_RMSE,test_%RMSE,train_%RMSE
shortcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Conf1,0.283 +/- 0.108,0.293 +/- 0.148,0.992 +/- 0.062,1.007 +/- 0.085,59.00 +/- 5.142,59.91 +/- 6.517
Conf2,0.260 +/- 0.161,0.316 +/- 0.101,1.003 +/- 0.065,0.994 +/- 0.076,59.59 +/- 4.568,59.02 +/- 4.552
Conf3,0.301 +/- 0.113,0.458 +/- 0.212,0.981 +/- 0.092,0.871 +/- 0.147,57.98 +/- 5.346,51.60 +/- 9.792
Conf4,0.287 +/- 0.105,0.417 +/- 0.087,0.992 +/- 0.094,0.917 +/- 0.087,58.98 +/- 5.980,54.50 +/- 5.361


Unnamed: 0,target,config,model_name,model,hyperparams,fold,model_obj,MSE,R2,MAPE,RMSE,NRMSE,MSE_train,R2_train,MAPE_train,RMSE_train,NRMSE_train
0,CSE,Conf1,,DeepCNN,,0,,0.782723,0.113575,0.896956,0.884716,50.932748,0.899737,0.438521,1.572387,0.948545,54.607333
1,CSE,Conf1,,DeepCNN,,1,,1.093818,0.303252,1.870286,1.045858,64.401534,1.324684,0.056682,1.823183,1.150949,70.872847
2,CSE,Conf1,,DeepCNN,,2,,1.000195,0.396092,2.123036,1.000098,57.470939,1.007327,0.266032,1.537747,1.003657,57.675474
3,CSE,Conf1,,DeepCNN,,3,,1.040353,0.254518,1.378651,1.019977,60.72405,0.881054,0.395222,1.456212,0.938645,55.881947
4,CSE,Conf1,,DeepCNN,,4,,1.026215,0.351788,2.527655,1.013023,61.496444,0.994591,0.310139,1.36989,0.997292,60.541505


metrics, predictions and models saved to
 ../results/metrics--CSE--Conf1--DeepCNN.pickle


Unnamed: 0,target,config,model_name,model,hyperparams,fold,model_obj,MSE,R2,MAPE,RMSE,NRMSE,MSE_train,R2_train,MAPE_train,RMSE_train,NRMSE_train
5,CSE,Conf2,,DeepCNN,,0,,0.857896,0.028441,1.171153,0.926227,53.886,0.961773,0.400798,1.83876,0.9807,57.055148
6,CSE,Conf2,,DeepCNN,,1,,0.922712,0.412244,1.44268,0.960579,58.851049,0.814907,0.414389,1.416649,0.902722,55.306355
7,CSE,Conf2,,DeepCNN,,2,,0.980498,0.407985,1.771646,0.990201,56.989007,0.918973,0.333576,1.327912,0.958631,55.172049
8,CSE,Conf2,,DeepCNN,,3,,1.135148,0.186591,1.488621,1.065433,63.008759,1.221431,0.169819,1.528337,1.105184,65.359569
9,CSE,Conf2,,DeepCNN,,4,,1.158599,0.268167,2.952475,1.076382,65.241497,1.053917,0.261986,1.751007,1.026605,62.224371


metrics, predictions and models saved to
 ../results/metrics--CSE--Conf2--DeepCNN.pickle


Unnamed: 0,target,config,model_name,model,hyperparams,fold,model_obj,MSE,R2,MAPE,RMSE,NRMSE,MSE_train,R2_train,MAPE_train,RMSE_train,NRMSE_train
10,CSE,Conf3,,DeepCNN,,0,,0.722399,0.181891,0.946131,0.849941,49.25858,0.574751,0.642903,1.394671,0.758123,43.937277
11,CSE,Conf3,,DeepCNN,,1,,0.897678,0.428191,1.19893,0.947459,57.884681,1.226632,0.111315,2.002184,1.107534,67.664441
12,CSE,Conf3,,DeepCNN,,2,,1.169873,0.293643,2.144593,1.081607,62.231938,0.710217,0.484594,1.392057,0.842744,48.488599
13,CSE,Conf3,,DeepCNN,,3,,1.115734,0.200502,1.204571,1.056283,62.472606,0.822346,0.43715,1.597435,0.906833,53.63356
14,CSE,Conf3,,DeepCNN,,4,,0.94391,0.403776,2.617948,0.97155,58.089289,0.549566,0.617231,1.191054,0.741327,44.324172


metrics, predictions and models saved to
 ../results/metrics--CSE--Conf3--DeepCNN.pickle


Unnamed: 0,target,config,model_name,model,hyperparams,fold,model_obj,MSE,R2,MAPE,RMSE,NRMSE,MSE_train,R2_train,MAPE_train,RMSE_train,NRMSE_train
15,CSE,Conf4,,DeepCNN,,0,,0.708064,0.198124,1.020415,0.841466,48.724822,0.981724,0.390607,1.884952,0.99082,57.37315
16,CSE,Conf4,,DeepCNN,,1,,0.93715,0.403048,1.51473,0.968065,60.037567,0.784933,0.433159,1.601548,0.885965,54.945855
17,CSE,Conf4,,DeepCNN,,2,,1.078456,0.348839,1.624984,1.038487,59.95108,0.615374,0.554998,1.005164,0.784458,45.28616
18,CSE,Conf4,,DeepCNN,,3,,1.177987,0.155894,1.316538,1.085351,64.039816,0.9961,0.315678,1.677223,0.998048,58.888617
19,CSE,Conf4,,DeepCNN,,4,,1.060762,0.329966,1.709271,1.029933,62.180184,0.861826,0.394362,1.304558,0.928346,56.047061


metrics, predictions and models saved to
 ../results/metrics--CSE--Conf4--DeepCNN.pickle
