In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
import time
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import xgboost as xgb

## Configuration

In [16]:
# Number of force cells in the robotic leg
N_CELLS = 8

# Path where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the training and validation data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0005_19042021'
# Hyperparameters search date
HS_DATE = '20042021'
# Number of folds in cross-validation
CV = 6

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0005_19042021


## Hyperparameters seach analysis

In [17]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, 'XGB_{}'.format(HS_DATE), 'XGB_{}_*.json'.format(HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 123


In [18]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_booster,param_eta,param_gamma,param_lambda,param_max_depth,param_nthread,param_objective,param_seed,param_subsample,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Valid_Fz_MAE_mean__mean,Valid_Fz_MAE_mean__std,Valid_Fz_MAE_std__mean,Valid_Fz_MAE_std__std,Valid_Fz_MSE_mean__mean,Valid_Fz_MSE_mean__std,Valid_Fz_MSE_std__mean,Valid_Fz_MSE_std__std,Valid_Fz_R2_mean__mean,Valid_Fz_R2_mean__std,Valid_Fz_R2_std__mean,Valid_Fz_R2_std__std
0,69JLNJO5RI,gbtree,0.5,0.010,1,3,8,reg:squarederror,0,0.75,65.098543,1.218346,8.660381,0.308246,2.542358,0.517207,163.166964,15.163582,100.421780,28.924884,0.455936,0.027535,0.109046,0.025455,7.776703,0.415907,4.719952,0.558537,169.018689,27.446492,181.757054,44.655947,0.343985,0.033728,0.125440,0.013805,13.421215,0.411529,5.079392,0.202952,391.478448,16.918628,331.995217,16.365414,0.471004,0.034518,0.093580,0.047410,9.668765,0.416967,2.846736,0.268242,198.910371,18.467043,118.662945,27.952420,0.300448,0.036133,0.124305,0.040222,8.220168,0.223061,4.893021,0.350033,190.293417,13.029829,207.056361,16.405125,0.212439,0.047328,0.154688,0.018848,15.182772,0.608899,6.123708,0.600148,513.558613,35.236396,485.601923,86.764550,0.289525,0.049084,0.157594,0.087675
1,FTN2A7XGI1,gbtree,0.3,0.010,2,4,8,reg:squarederror,0,0.50,72.803450,1.055087,8.131550,0.170416,2.320002,0.143161,140.605485,5.791586,79.040401,7.089656,0.523109,0.016271,0.079205,0.005475,7.394806,0.301609,4.368400,0.435084,148.119109,17.663247,153.588425,30.499090,0.399236,0.019840,0.121197,0.018887,12.614450,0.425561,4.647628,0.166900,345.375069,19.869002,278.672046,15.135022,0.525071,0.035966,0.082801,0.036251,9.504354,0.464097,2.783744,0.366424,193.427923,19.979574,113.535239,28.901364,0.319120,0.040427,0.116487,0.029223,8.097245,0.225779,4.728623,0.344174,184.085513,14.206924,197.393287,18.571446,0.220997,0.044704,0.162695,0.019789,14.870117,0.733362,5.960564,0.685565,495.101110,32.770727,458.727406,78.059753,0.309798,0.052237,0.163587,0.092558
2,YO1GY3J97Z,gbtree,0.3,0.005,2,6,8,reg:squarederror,0,0.75,131.329646,2.483775,6.801254,0.268628,2.042166,0.329952,98.394610,8.178699,58.186013,13.196209,0.663714,0.022251,0.080217,0.012457,6.538690,0.393707,3.821124,0.475121,114.196796,17.691105,116.904412,29.840202,0.521767,0.032929,0.126657,0.023254,10.236166,0.465187,3.646358,0.284529,229.639473,19.016817,166.033110,11.773164,0.670938,0.030992,0.074857,0.034168,9.082600,0.507941,2.639738,0.316877,180.933245,23.465321,103.559918,27.227847,0.356618,0.048333,0.131221,0.023144,7.909461,0.230306,4.543698,0.441839,178.323619,15.424788,187.722294,23.132875,0.227755,0.056994,0.180358,0.025686,14.282780,0.639077,5.739485,0.712269,477.628953,25.852428,444.908473,62.876727,0.332657,0.059080,0.174887,0.100470
3,UJV92DD3F7,gbtree,0.5,0.010,0,5,8,reg:squarederror,0,0.75,104.128809,3.515423,7.418362,0.501350,2.335507,0.633854,119.208314,19.292888,75.582454,29.531979,0.601572,0.043286,0.092090,0.029076,6.949187,0.435193,4.041769,0.573714,130.292602,22.812850,134.699693,37.344446,0.455574,0.040697,0.150768,0.024200,10.945919,0.397552,4.221357,0.294928,270.190197,21.788081,221.423717,32.777151,0.631204,0.028211,0.082450,0.041121,9.341590,0.441092,2.702746,0.340526,187.560997,21.880392,106.484180,27.019316,0.331383,0.042428,0.134246,0.032062,8.003050,0.195952,4.628647,0.357982,180.977785,11.834386,193.684427,18.200036,0.225054,0.054762,0.177774,0.022415,14.422291,0.606310,5.814055,0.771452,480.552056,32.034211,446.787593,77.747788,0.326294,0.059394,0.172891,0.103983
4,UIIA5POF5I,gbtree,0.3,0.010,2,7,8,reg:squarederror,0,0.75,155.088028,3.172026,6.294416,0.240647,2.021956,0.324676,86.543521,6.616950,54.718179,11.961837,0.707899,0.019719,0.086159,0.010028,6.097978,0.485892,3.572792,0.554560,97.794590,19.388475,99.352360,31.542105,0.577916,0.044321,0.138918,0.040608,9.216268,0.554022,3.218922,0.285926,186.932261,20.638034,125.747019,14.075482,0.723520,0.036936,0.076536,0.032008,8.978160,0.467510,2.542490,0.331031,178.388202,23.855861,99.675800,27.393087,0.359898,0.053070,0.141001,0.025642,7.842939,0.219761,4.432838,0.368441,176.224368,15.657473,184.430153,23.924132,0.225376,0.055442,0.186949,0.028756,14.157323,0.660834,5.677264,0.756870,474.755313,30.919891,438.566749,65.050225,0.337478,0.050418,0.181055,0.097151
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
118,8XRNAK6H9B,gbtree,0.4,0.050,2,8,8,reg:squarederror,0,1.00,196.899457,5.068455,5.809364,0.343563,2.040881,0.503744,75.818661,10.080077,51.931043,19.620255,0.743002,0.025383,0.108610,0.018926,5.789487,0.695007,3.474398,0.836647,90.647517,31.480053,93.373273,48.053907,0.616614,0.052345,0.147892,0.031640,7.880606,0.501369,2.771567,0.376948,141.767760,14.888642,98.141837,10.816475,0.787100,0.029375,0.068803,0.033504,8.938685,0.490724,2.493701,0.353458,177.879453,23.810811,97.728606,27.705479,0.357773,0.057061,0.148745,0.030238,7.872576,0.256062,4.461127,0.382873,178.976547,16.590510,187.381172,25.272975,0.213508,0.059074,0.188839,0.029837,14.103159,0.626670,5.696698,0.702248,482.371961,28.346395,454.304864,52.562666,0.330400,0.048349,0.183324,0.094388
119,PFWE224BA5,gbtree,0.4,0.005,2,4,8,reg:squarederror,0,1.00,99.405916,2.415991,8.057840,0.319555,2.449871,0.439062,138.468480,12.621986,84.469037,21.919993,0.534869,0.028081,0.092857,0.022045,7.369684,0.423978,4.402368,0.628493,147.421462,24.657249,154.917325,43.237132,0.405740,0.027928,0.129141,0.011227,12.343596,0.478478,4.683886,0.198619,332.743428,22.219922,272.478044,13.846841,0.548041,0.037084,0.082002,0.033863,9.439124,0.468855,2.746469,0.268015,190.551247,20.355638,111.167807,26.591859,0.326239,0.039341,0.120096,0.034165,8.013045,0.174589,4.649355,0.319782,179.683720,10.872351,191.504461,13.580132,0.230676,0.048943,0.166423,0.023578,14.808933,0.692099,5.946387,0.706575,494.699427,35.522366,464.589255,81.151664,0.311068,0.049260,0.166534,0.097553
120,WK7G2BI77W,gbtree,0.3,0.010,1,8,8,reg:squarederror,0,1.00,203.976877,4.721059,5.737759,0.273420,1.855349,0.203941,72.422355,6.260769,44.617223,6.191393,0.751368,0.025265,0.089382,0.019091,5.867993,0.603874,3.435200,0.676632,92.103131,26.453346,93.735318,38.017957,0.606559,0.051327,0.137934,0.036382,8.124729,0.584539,2.827746,0.469751,148.307618,21.414457,97.140128,20.120733,0.777734,0.032684,0.068878,0.036092,8.922712,0.527059,2.517248,0.311742,177.870967,23.701475,98.757616,23.920635,0.360192,0.054013,0.144085,0.026517,7.850801,0.282187,4.443856,0.396747,178.115529,16.435981,186.929883,24.321459,0.218413,0.056812,0.187177,0.028383,13.994986,0.564889,5.643784,0.693085,473.442149,21.199245,437.835530,44.357383,0.337801,0.058053,0.185817,0.101036
121,ZDMKGSRAOE,gbtree,0.3,0.005,2,7,8,reg:squarederror,0,0.75,153.540638,3.200350,6.294416,0.240647,2.021956,0.324676,86.543521,6.616950,54.718179,11.961837,0.707899,0.019719,0.086159,0.010028,6.097978,0.485892,3.572792,0.554560,97.794590,19.388475,99.352360,31.542105,0.577916,0.044321,0.138918,0.040608,9.216268,0.554022,3.218922,0.285926,186.932261,20.638034,125.747019,14.075482,0.723520,0.036936,0.076536,0.032008,8.978160,0.467510,2.542490,0.331031,178.388202,23.855861,99.675800,27.393087,0.359898,0.053070,0.141001,0.025642,7.842939,0.219761,4.432838,0.368441,176.224368,15.657473,184.430153,23.924132,0.225376,0.055442,0.186949,0.028756,14.157323,0.660834,5.677264,0.756870,474.755313,30.919891,438.566749,65.050225,0.337478,0.050418,0.181055,0.097151


In [19]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Valid']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy', 'Fz']]].mean(axis=1)

In [20]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Valid_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_booster,param_eta,param_gamma,param_lambda,param_max_depth,param_nthread,param_objective,param_seed,param_subsample,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Valid_Fz_MAE_mean__mean,Valid_Fz_MAE_mean__std,Valid_Fz_MAE_std__mean,Valid_Fz_MAE_std__std,Valid_Fz_MSE_mean__mean,Valid_Fz_MSE_mean__std,Valid_Fz_MSE_std__mean,Valid_Fz_MSE_std__std,Valid_Fz_R2_mean__mean,Valid_Fz_R2_mean__std,Valid_Fz_R2_std__mean,Valid_Fz_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
4,UIIA5POF5I,gbtree,0.3,0.010,2,7,8,reg:squarederror,0,0.75,155.088028,3.172026,6.294416,0.240647,2.021956,0.324676,86.543521,6.616950,54.718179,11.961837,0.707899,0.019719,0.086159,0.010028,6.097978,0.485892,3.572792,0.554560,97.794590,19.388475,99.352360,31.542105,0.577916,0.044321,0.138918,0.040608,9.216268,0.554022,3.218922,0.285926,186.932261,20.638034,125.747019,14.075482,0.723520,0.036936,0.076536,0.032008,8.978160,0.467510,2.542490,0.331031,178.388202,23.855861,99.675800,27.393087,0.359898,0.053070,0.141001,0.025642,7.842939,0.219761,4.432838,0.368441,176.224368,15.657473,184.430153,23.924132,0.225376,0.055442,0.186949,0.028756,14.157323,0.660834,5.677264,0.756870,474.755313,30.919891,438.566749,65.050225,0.337478,0.050418,0.181055,0.097151,7.202888,123.756791,0.669778,10.326141,276.455961,0.307584
121,ZDMKGSRAOE,gbtree,0.3,0.005,2,7,8,reg:squarederror,0,0.75,153.540638,3.200350,6.294416,0.240647,2.021956,0.324676,86.543521,6.616950,54.718179,11.961837,0.707899,0.019719,0.086159,0.010028,6.097978,0.485892,3.572792,0.554560,97.794590,19.388475,99.352360,31.542105,0.577916,0.044321,0.138918,0.040608,9.216268,0.554022,3.218922,0.285926,186.932261,20.638034,125.747019,14.075482,0.723520,0.036936,0.076536,0.032008,8.978160,0.467510,2.542490,0.331031,178.388202,23.855861,99.675800,27.393087,0.359898,0.053070,0.141001,0.025642,7.842939,0.219761,4.432838,0.368441,176.224368,15.657473,184.430153,23.924132,0.225376,0.055442,0.186949,0.028756,14.157323,0.660834,5.677264,0.756870,474.755313,30.919891,438.566749,65.050225,0.337478,0.050418,0.181055,0.097151,7.202888,123.756791,0.669778,10.326141,276.455961,0.307584
75,TKZ3SX54J2,gbtree,0.3,0.050,2,7,8,reg:squarederror,0,0.75,153.731990,3.019236,6.294416,0.240647,2.021956,0.324676,86.543521,6.616950,54.718179,11.961837,0.707899,0.019719,0.086159,0.010028,6.097978,0.485892,3.572792,0.554560,97.794590,19.388475,99.352360,31.542105,0.577916,0.044321,0.138918,0.040608,9.216268,0.554022,3.218922,0.285926,186.932261,20.638034,125.747019,14.075482,0.723520,0.036936,0.076536,0.032008,8.978160,0.467510,2.542490,0.331031,178.388202,23.855861,99.675800,27.393087,0.359898,0.053070,0.141001,0.025642,7.842939,0.219761,4.432837,0.368441,176.224372,15.657475,184.430150,23.924129,0.225375,0.055442,0.186949,0.028756,14.157323,0.660834,5.677264,0.756870,474.755313,30.919891,438.566749,65.050225,0.337478,0.050418,0.181055,0.097151,7.202888,123.756791,0.669778,10.326141,276.455963,0.307584
91,8XY4213EEP,gbtree,0.3,0.010,0,8,8,reg:squarederror,0,0.75,177.400117,3.839991,5.790228,0.324547,2.020671,0.391398,74.054924,8.483433,49.932794,13.891607,0.750307,0.024656,0.091934,0.017148,5.802969,0.607896,3.360440,0.714051,88.825517,24.791089,88.682749,34.834281,0.611216,0.048073,0.141269,0.040425,8.152538,0.533751,2.822437,0.406030,147.874712,17.531976,94.621264,14.549147,0.775742,0.031362,0.073325,0.036804,8.924186,0.561488,2.512101,0.337308,177.227975,24.713891,97.516687,25.522056,0.360086,0.059941,0.150123,0.029354,7.805313,0.259178,4.393205,0.396740,175.990271,16.729043,184.148301,25.075015,0.220827,0.059752,0.188526,0.028088,13.983373,0.578395,5.603109,0.640614,474.759895,21.172475,445.734456,47.334010,0.341644,0.054319,0.183039,0.094337,6.581912,103.585051,0.712422,10.237624,275.992714,0.307519
7,G60E820FRQ,gbtree,0.3,0.005,0,8,8,reg:squarederror,0,0.75,175.505264,3.596184,5.790228,0.324547,2.020671,0.391398,74.054924,8.483434,49.932793,13.891607,0.750307,0.024656,0.091934,0.017148,5.802969,0.607896,3.360440,0.714051,88.825517,24.791089,88.682749,34.834281,0.611216,0.048073,0.141269,0.040425,8.140793,0.531502,2.838163,0.396547,147.669434,17.502275,94.844503,14.576666,0.777210,0.031093,0.071144,0.037292,8.924186,0.561488,2.512101,0.337308,177.227976,24.713891,97.516687,25.522056,0.360086,0.059941,0.150123,0.029354,7.805313,0.259178,4.393205,0.396740,175.990270,16.729043,184.148301,25.075015,0.220827,0.059752,0.188526,0.028088,13.982631,0.577747,5.604133,0.641824,474.773053,21.190801,445.722714,47.322263,0.341549,0.054351,0.182930,0.094333,6.577997,103.516625,0.712911,10.237377,275.997100,0.307487
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
67,XAS93RH1V0,gbtree,0.3,0.005,1,3,8,reg:squarederror,0,1.00,76.644844,0.704555,8.833268,0.228775,2.588302,0.204849,167.363427,10.528999,98.818256,13.188300,0.439575,0.022575,0.087791,0.017311,7.902968,0.235182,4.740477,0.311502,172.725997,14.327931,183.740934,22.386411,0.322175,0.015863,0.110351,0.016408,13.808963,0.402491,5.164766,0.169320,412.897210,21.948463,350.446443,15.754030,0.442483,0.034952,0.089238,0.032008,9.742511,0.492889,2.991434,0.437757,204.460448,22.714025,128.880397,36.519706,0.293614,0.039180,0.110364,0.035998,8.287129,0.288150,4.943737,0.317876,194.675501,16.748244,215.293420,19.229717,0.206492,0.041521,0.142937,0.023193,15.283627,0.749223,5.934859,0.564441,515.588586,32.058946,472.512366,72.911592,0.276170,0.049205,0.161650,0.079121,10.181733,250.995545,0.401411,11.104422,304.908178,0.258759
15,MWK74CGKOO,gbtree,0.3,0.005,1,3,8,reg:squarederror,0,0.50,56.339713,0.509179,8.847647,0.259727,2.616724,0.244966,168.501477,12.037874,100.693877,15.166963,0.437174,0.026239,0.092902,0.020349,7.927647,0.239138,4.788287,0.295888,175.718158,14.904770,188.193572,23.266535,0.318907,0.018651,0.109418,0.015123,13.777552,0.366123,5.169470,0.159497,413.125748,20.584532,355.012184,14.887012,0.445223,0.031887,0.086578,0.033847,9.743784,0.504088,2.923219,0.419714,203.912833,21.234995,124.569649,33.972128,0.290311,0.041904,0.112591,0.040753,8.295251,0.304020,4.973948,0.393699,196.806829,18.358638,219.938325,23.900770,0.205731,0.041139,0.142490,0.021996,15.293896,0.771030,5.969752,0.598162,517.027869,35.482946,476.543691,81.840500,0.279053,0.046942,0.160206,0.076994,10.184282,252.448461,0.400435,11.110977,305.915844,0.258365
44,EUBAU593MK,gbtree,0.3,0.050,2,3,8,reg:squarederror,0,0.50,56.223704,0.543661,8.845616,0.262132,2.616087,0.245076,168.380666,12.137230,100.633315,15.229113,0.437454,0.026439,0.093020,0.020622,7.923044,0.239443,4.792514,0.293747,175.687263,14.901334,188.238448,23.267861,0.319958,0.018066,0.108056,0.013590,13.786273,0.378496,5.170601,0.159836,413.675073,21.275247,354.936740,14.812728,0.444368,0.033281,0.086377,0.033551,9.744980,0.499978,2.921230,0.420733,203.835978,21.114613,124.432891,34.067184,0.290207,0.041704,0.112153,0.040989,8.293713,0.302964,4.975298,0.394771,196.811098,18.360051,219.945819,23.900001,0.205627,0.041142,0.142565,0.022021,15.297106,0.772721,5.971350,0.599491,517.153946,35.488244,476.555858,81.842369,0.278877,0.046622,0.160417,0.076926,10.184977,252.581001,0.400593,11.111933,305.933674,0.258237
84,OS9D1ANCWO,gbtree,0.3,0.005,1,3,8,reg:squarederror,0,0.75,67.301810,0.616080,8.901032,0.192718,2.653298,0.226508,171.817551,9.085201,104.046029,14.417723,0.426976,0.018151,0.100884,0.026770,7.853594,0.214736,4.696650,0.287175,169.711974,12.336602,179.104397,20.896949,0.328447,0.014271,0.108991,0.011323,13.740485,0.378942,5.121020,0.120246,408.680515,19.157052,346.762128,14.310248,0.447516,0.035686,0.082834,0.033863,9.755718,0.483621,3.002743,0.402370,205.659773,21.198513,128.586279,34.271390,0.288433,0.038319,0.115225,0.038180,8.305280,0.293236,4.978559,0.335470,195.573358,18.824472,217.313054,23.197974,0.205453,0.040183,0.140871,0.024709,15.340766,0.751931,5.998239,0.605835,519.923983,34.591185,481.591867,83.727587,0.272600,0.048193,0.166679,0.087475,10.165037,250.070013,0.400980,11.133921,307.052371,0.255495


In [21]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'booster': 'gbtree', 'eta': 0.3, 'gamma': 0.01, 'lambda': 2, 'max_depth': 7, 'nthread': 8, 'objective': 'reg:squarederror', 'seed': 0, 'subsample': 0.75}


## Best model

In [22]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_train_{}.npy'.format(DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_test_{}.npy'.format(DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_train_{}.npy'.format(DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_test_{}.npy'.format(DATA_ID)))

In [25]:
results = defaultdict(list)
tr_time = []
for target in range(Y_train.shape[1]):

    dtrain = xgb.DMatrix(data=X_train, label=Y_train[:, target])
    dtest = xgb.DMatrix(data=X_test, label=Y_test[:, target])

    callbacks = []#[xgb.callback.EarlyStopping(rounds=5, metric_name='rmse', maximize=False, save_best=True)]
    
    t_start = time.time()
    model = xgb.train(best_params, dtrain, evals=[(dtest, 'rmse')], callbacks=callbacks, verbose_eval=False)
    tr_time.append(time.time() - t_start)
    
    # Save the model
    model.save_model(os.path.join(RESULTS_PATH, DATA_ID, 'XGB_{}'.format(HS_DATE), 'XGB_best_model_{}_{}_{}.joblib'.format(target, HS_DATE, DATA_ID)))
    
    train_preds = model.predict(dtrain)
    test_preds = model.predict(dtest)

    results['Train_MAE'].append(mean_absolute_error(Y_train[:, target], train_preds))
    results['Train_MSE'].append(mean_squared_error(Y_train[:, target], train_preds))
    results['Train_R2'].append(r2_score(Y_train[:, target], train_preds))
    results['Test_MAE'].append(mean_absolute_error(Y_test[:, target], test_preds))
    results['Test_MSE'].append(mean_squared_error(Y_test[:, target], test_preds))
    results['Test_R2'].append(r2_score(Y_test[:, target], test_preds))

print('Training time: {:.4f}'.format(sum(tr_time)))

Training time: 421.2417


In [26]:
# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy', 'Fz']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results['_'.join([subset, loss])][i + f] for i in range(0, N_CELLS * 3, 3)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

Train Fx MAE: 5.9629 ± 1.6847
Train Fx MSE: 76.4576 ± 42.7314
Train Fx R2: 0.7376 ± 0.0517
Train Fy MAE: 5.7177 ± 3.1862
Train Fy MSE: 84.2443 ± 78.6247
Train Fy R2: 0.6201 ± 0.1068
Train Fz MAE: 9.1344 ± 3.2695
Train Fz MSE: 185.7676 ± 132.3666
Train Fz R2: 0.7378 ± 0.0519
Test Fx MAE: 11.0853 ± 3.8290
Test Fx MSE: 273.7327 ± 182.5694
Test Fx R2: 0.1542 ± 0.2007
Test Fy MAE: 9.8996 ± 6.3373
Test Fy MSE: 324.9241 ± 389.9810
Test Fy R2: 0.0324 ± 0.1978
Test Fz MAE: 16.9357 ± 7.2828
Test Fz MSE: 685.7255 ± 688.0254
Test Fz R2: 0.1111 ± 0.2064


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()