In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
import time
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import xgboost as xgb

## Configuration

In [2]:
# --- Problem size -----------------------------------------------------------
# Number of force cells in the robotic leg
N_CELLS = 8
# Number of folds in cross-validation
CV = 6

# --- Locations and identifiers ----------------------------------------------
# Root folder where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the dataset (training and test arrays) stored under RESULTS_PATH
DATA_ID = '0005_19042021'
# Date tag of the hyperparameters search whose results are analysed below
HS_DATE = '19042021'

# --- Display ----------------------------------------------------------------
# Do not truncate the columns when a dataframe is displayed
pd.options.display.max_columns = None

print('Model trained with data: ' + DATA_ID)

Model trained with data: 0005_19042021


## Hyperparameter search analysis

In [3]:
# Collect the JSON result files of the hyperparameters search tagged HS_DATE.
# glob returns its matches in a file-system dependent order, so the list is
# sorted to keep the row order (and index labels) of the dataframe built from
# it reproducible between runs and machines.
results_dir = os.path.join(RESULTS_PATH, DATA_ID, 'XGB_{}'.format(HS_DATE))
results_files_ls = sorted(glob.glob(os.path.join(results_dir, 'XGB_{}_*.json'.format(HS_DATE))))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 18


In [4]:
# Build one summary row per results file, then assemble the rows into a pandas dataframe
results_ls = []
for results_path in results_files_ls:
    with open(results_path) as results_fh:
        loaded = json.load(results_fh)

    # Identifier first, followed by the hyperparameters (prefixed with 'param_')
    row = {'params_ID': loaded['id']}
    row.update({'param_' + name: setting for name, setting in loaded['parameters'].items()})

    # Every entry of 'cv_results' is reduced to its mean and its standard deviation
    for metric, values in loaded['cv_results'].items():
        row[metric + '__mean'] = np.mean(values)
        row[metric + '__std'] = np.std(values)

    results_ls.append(row)

results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_booster,param_eta,param_gamma,param_lambda,param_max_depth,param_nthread,param_objective,param_seed,param_subsample,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Test_Fx_MAE_mean__mean,Test_Fx_MAE_mean__std,Test_Fx_MAE_std__mean,Test_Fx_MAE_std__std,Test_Fx_MSE_mean__mean,Test_Fx_MSE_mean__std,Test_Fx_MSE_std__mean,Test_Fx_MSE_std__std,Test_Fx_R2_mean__mean,Test_Fx_R2_mean__std,Test_Fx_R2_std__mean,Test_Fx_R2_std__std,Test_Fy_MAE_mean__mean,Test_Fy_MAE_mean__std,Test_Fy_MAE_std__mean,Test_Fy_MAE_std__std,Test_Fy_MSE_mean__mean,Test_Fy_MSE_mean__std,Test_Fy_MSE_std__mean,Test_Fy_MSE_std__std,Test_Fy_R2_mean__mean,Test_Fy_R2_mean__std,Test_Fy_R2_std__mean,Test_Fy_R2_std__std,Test_Fz_MAE_mean__mean,Test_Fz_MAE_mean__std,Test_Fz_MAE_std__mean,Test_Fz_MAE_std__std,Test_Fz_MSE_mean__mean,Test_Fz_MSE_mean__std,Test_Fz_MSE_std__mean,Test_Fz_MSE_std__std,Test_Fz_R2_mean__mean,Test_Fz_R2_mean__std,Test_Fz_R2_std__mean,Test_Fz_R2_std__std
0,P63LB1GSD7,gbtree,0.2,0.01,1.0,15,8,reg:squarederror,0,1.0,407.360964,7.571134,3.929873,0.295803,1.649186,0.178539,35.166832,5.004749,28.175675,7.174984,0.880481,0.016705,0.08071,0.010682,4.359995,1.050321,2.769038,1.00063,58.059231,28.185341,65.648273,33.923707,0.773871,0.071547,0.130074,0.057365,5.539042,0.383757,2.128325,0.407765,62.455117,8.003072,39.499846,9.401047,0.903713,0.0146,0.038891,0.019795,9.135363,0.661232,2.56551,0.339938,189.357899,26.744436,103.001402,22.880423,0.314785,0.061079,0.165634,0.019312,8.045571,0.38839,4.379696,0.3817,188.01716,22.946639,195.167502,32.351094,0.141455,0.073916,0.222901,0.038939,14.634289,0.930613,6.008883,0.778121,535.661219,37.497986,492.182347,55.38876,0.257017,0.067955,0.216929,0.108998
1,H67AXHLY06,gbtree,0.3,0.05,0.8,30,8,reg:squarederror,0,1.0,712.268577,33.324726,2.840873,0.520416,2.088371,0.314985,25.318418,8.756451,32.253344,16.151761,0.914137,0.022135,0.101526,0.034722,3.83424,1.362203,2.86519,1.059335,53.615998,31.819607,66.375297,32.74552,0.792843,0.079514,0.160707,0.044359,3.117479,1.260741,1.720119,0.614502,23.533447,14.794557,21.905091,12.319507,0.958212,0.025013,0.03615,0.01652,9.443974,0.686706,2.536485,0.253819,201.110101,27.294826,105.538318,23.753954,0.2668,0.064058,0.180192,0.024382,8.40996,0.409865,4.539365,0.35287,201.738047,25.534681,207.699643,36.866825,0.06651,0.071279,0.260368,0.063677,14.917142,0.772594,6.045794,0.651397,561.25215,36.653781,508.504844,62.225038,0.212599,0.052057,0.228067,0.091205
2,7T8QE34V4X,gbtree,0.3,0.05,1.0,30,8,reg:squarederror,0,0.75,624.95802,24.826554,2.636239,0.615589,1.805612,0.312047,20.690286,7.654565,22.926566,9.335936,0.927869,0.023936,0.077042,0.016949,3.539635,1.302609,2.581526,1.108669,45.205678,28.868606,55.107759,32.364831,0.811,0.075645,0.163944,0.052588,3.030591,1.172436,1.578006,0.602306,23.119706,14.353755,21.975957,12.484869,0.960531,0.024565,0.034917,0.016054,9.304694,0.689449,2.490216,0.240271,194.330475,27.289184,101.828443,22.1977,0.290838,0.064583,0.175559,0.02317,8.295945,0.401797,4.464522,0.380635,196.188216,25.654423,201.287413,36.34975,0.090996,0.072403,0.248597,0.052953,14.798087,0.787244,6.05586,0.698169,554.863563,37.706478,508.80133,66.258032,0.227665,0.056039,0.224512,0.091493
3,Q4ZPOSZ4QC,gbtree,0.4,0.01,1.0,30,8,reg:squarederror,0,1.0,704.790866,47.364074,2.646745,0.786423,2.111895,0.361609,24.035163,10.384174,30.832964,14.576846,0.919199,0.032336,0.093733,0.032523,3.460587,1.333607,2.630536,1.011383,44.256205,30.185975,55.242284,39.627895,0.816564,0.075983,0.149609,0.041089,2.603618,1.626497,1.856446,0.890482,21.68969,21.047587,22.08582,19.38914,0.962009,0.034276,0.037052,0.026933,9.451413,0.673065,2.510028,0.227064,200.668036,26.562711,104.008954,22.94738,0.265779,0.063461,0.181715,0.023684,8.416528,0.388255,4.533754,0.356665,202.833387,26.224958,208.603954,37.622008,0.060329,0.073728,0.260935,0.060798,14.932658,0.700271,6.047708,0.628483,564.388882,37.226558,511.12654,64.619378,0.207896,0.054399,0.22926,0.087801
4,SQ137IWT23,gbtree,0.4,0.1,0.8,6,8,reg:squarederror,0,0.5,105.037324,2.74321,6.836908,0.364652,2.071451,0.44464,100.583698,11.683508,59.92314,19.661538,0.65577,0.032064,0.099622,0.023449,6.479819,0.453406,3.7144,0.464092,112.177116,22.064625,114.086341,34.775933,0.519527,0.043509,0.144558,0.028825,9.97118,0.393512,3.570111,0.163951,221.025561,14.212194,164.128499,12.176379,0.682718,0.029614,0.077691,0.027732,9.110351,0.495971,2.61718,0.306478,181.719816,23.090115,102.551094,25.662146,0.349562,0.053085,0.136338,0.031301,7.882462,0.244557,4.483528,0.413694,176.240137,15.492662,184.317626,22.650116,0.229189,0.051201,0.183571,0.031421,14.300803,0.596519,5.739543,0.680234,477.979273,28.223068,444.394189,64.735382,0.328856,0.056647,0.178452,0.098256
5,WFC1L57OY7,gbtree,0.4,0.01,0.8,6,8,reg:squarederror,0,0.75,128.852457,4.48254,6.641679,0.274204,1.902068,0.273806,94.517723,8.138756,52.820229,11.399821,0.673734,0.025379,0.082815,0.014188,6.58561,0.650551,3.907084,0.793151,119.674487,34.875506,125.134847,53.68617,0.515775,0.056879,0.1448,0.01813,10.084164,0.574321,3.596153,0.236421,223.374698,20.5892,159.871132,13.83609,0.676088,0.040836,0.082241,0.031188,9.090498,0.496375,2.583047,0.295293,181.771776,22.963477,101.914846,27.225507,0.350174,0.051087,0.131658,0.029234,7.903189,0.266383,4.516694,0.441777,178.483862,15.31232,188.698256,22.990408,0.226248,0.058746,0.179775,0.027372,14.285223,0.606498,5.706206,0.694397,474.895773,27.284464,442.082034,60.682194,0.335315,0.058059,0.17195,0.098401
6,X2M36O6DH3,gbtree,0.4,0.01,1.0,10,8,reg:squarederror,0,1.0,246.144911,9.487769,4.805576,0.420795,1.888233,0.55294,55.31228,10.555639,42.552025,19.419592,0.812687,0.028575,0.103424,0.019271,5.263627,0.999363,3.257158,0.992078,80.303304,37.496748,88.669983,50.904713,0.671332,0.077595,0.149706,0.024249,6.318958,0.756582,2.368404,0.647758,95.315985,19.246409,66.126261,18.80987,0.851391,0.037975,0.067848,0.040316,8.962973,0.601726,2.473893,0.337558,181.620747,27.603399,97.296354,27.294735,0.342515,0.066644,0.157687,0.024215,7.87894,0.338865,4.38586,0.408588,179.8684,19.895146,186.381375,28.62224,0.188431,0.072421,0.205924,0.030558,14.096669,0.650531,5.793156,0.747725,491.684721,35.222493,461.66695,54.803342,0.32056,0.055776,0.198558,0.094055
7,WM6A1S4U3W,gbtree,0.2,0.01,1.2,4,8,reg:squarederror,0,0.75,89.252449,0.586679,8.535054,0.163454,2.614943,0.174418,156.157955,6.681633,94.350302,10.233762,0.477604,0.009818,0.077367,0.01323,7.593009,0.138688,4.398026,0.144129,154.863842,6.030149,158.402068,9.308903,0.365257,0.006576,0.118008,0.009731,13.331772,0.188941,4.924855,0.074306,392.361869,10.26378,326.17704,12.938858,0.469968,0.01754,0.071766,0.015439,9.72579,0.588451,3.048632,0.514188,203.79382,23.473756,125.716649,35.29346,0.292105,0.041801,0.109736,0.030895,8.24372,0.272811,4.82511,0.396018,190.740192,16.817063,207.360925,21.85635,0.206255,0.040508,0.151919,0.020014,15.428264,0.950209,6.121568,0.722889,535.761671,36.073797,493.008236,74.399368,0.257141,0.067469,0.170609,0.103448
8,UZN5JZ9SD9,gbtree,0.3,0.1,1.2,15,8,reg:squarederror,0,0.75,334.583869,8.885188,3.402077,0.534963,1.745105,0.19971,29.915026,7.326268,27.395353,7.789115,0.896996,0.024096,0.084987,0.013757,4.189589,1.052695,2.742377,0.917512,54.308792,26.914881,62.574729,29.690114,0.779416,0.06413,0.134298,0.037382,3.896318,0.988338,1.579005,0.619063,37.998527,16.268911,26.871014,14.804675,0.935482,0.026419,0.042828,0.01631,9.026674,0.666639,2.374402,0.240099,184.785738,27.34288,95.330261,23.865734,0.32396,0.066524,0.172123,0.016996,8.044724,0.358262,4.380703,0.393833,187.533423,22.543142,194.780314,32.699635,0.141087,0.069707,0.225941,0.042268,14.298267,0.702281,5.854042,0.664065,519.244815,34.957753,484.5292,64.404344,0.279349,0.060658,0.214174,0.094308
9,H66HMVC8X7,gbtree,0.4,0.01,1.2,15,8,reg:squarederror,0,0.5,260.047451,11.453883,3.512153,0.568856,1.961144,0.322954,33.243863,8.75618,32.449308,12.011767,0.887318,0.029213,0.096256,0.031448,4.207942,1.043767,2.973155,1.020394,58.775623,31.025934,76.93721,44.104504,0.773937,0.064531,0.144261,0.04698,3.780018,0.964495,1.648936,0.841083,38.005346,15.987519,28.777975,19.507521,0.93175,0.033313,0.04796,0.032183,9.115586,0.671247,2.423028,0.264624,187.553092,28.156992,97.251001,24.763163,0.31587,0.063939,0.168321,0.022972,8.126598,0.355654,4.466562,0.391519,190.831297,20.788623,197.508092,30.163479,0.127622,0.068225,0.228024,0.044798,14.439205,0.704382,5.878653,0.637437,526.047775,36.995728,485.433447,65.233605,0.266892,0.062841,0.216735,0.090125


In [5]:
# Sum up the scores by force axis (Fx, Fy, Fz) in only one sortable score per
# subset and loss. The cross-validation columns loaded from the results files
# are prefixed with 'Train_' and 'Test_' (there is no 'Valid_' column), so
# those are the two subsets aggregated here.
for subset in ['Train', 'Test']:
    for loss in ['MAE', 'MSE', 'R2']:
        axis_cols = [subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy', 'Fz']]
        results_df[subset + '_' + loss] = results_df[axis_cols].mean(axis=1)

# Sort the dataframe by the most relevant score: the aggregated R2 of the
# 'Test' cross-validation columns, best (highest) first
results_df = results_df.sort_values(['Test_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_booster,param_eta,param_gamma,param_lambda,param_max_depth,param_nthread,param_objective,param_seed,param_subsample,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Test_Fx_MAE_mean__mean,Test_Fx_MAE_mean__std,Test_Fx_MAE_std__mean,Test_Fx_MAE_std__std,Test_Fx_MSE_mean__mean,Test_Fx_MSE_mean__std,Test_Fx_MSE_std__mean,Test_Fx_MSE_std__std,Test_Fx_R2_mean__mean,Test_Fx_R2_mean__std,Test_Fx_R2_std__mean,Test_Fx_R2_std__std,Test_Fy_MAE_mean__mean,Test_Fy_MAE_mean__std,Test_Fy_MAE_std__mean,Test_Fy_MAE_std__std,Test_Fy_MSE_mean__mean,Test_Fy_MSE_mean__std,Test_Fy_MSE_std__mean,Test_Fy_MSE_std__std,Test_Fy_R2_mean__mean,Test_Fy_R2_mean__std,Test_Fy_R2_std__mean,Test_Fy_R2_std__std,Test_Fz_MAE_mean__mean,Test_Fz_MAE_mean__std,Test_Fz_MAE_std__mean,Test_Fz_MAE_std__std,Test_Fz_MSE_mean__mean,Test_Fz_MSE_mean__std,Test_Fz_MSE_std__mean,Test_Fz_MSE_std__std,Test_Fz_R2_mean__mean,Test_Fz_R2_mean__std,Test_Fz_R2_std__mean,Test_Fz_R2_std__std,Train_MAE,Train_MSE,Train_R2,Test_MAE,Test_MSE,Test_R2
15,II2ZWODEDU,gbtree,0.3,0.01,1.0,6,8,reg:squarederror,0,0.75,132.427205,2.486171,6.793238,0.238745,2.099205,0.305912,98.712573,7.304213,60.525326,12.370605,0.665521,0.018673,0.078124,0.013107,6.584058,0.451156,3.830365,0.530589,115.542388,21.903064,116.655806,33.594616,0.515698,0.034932,0.128127,0.018096,10.261865,0.493881,3.678117,0.238681,231.736952,19.367401,168.811004,9.571463,0.669413,0.034322,0.076154,0.035179,9.089635,0.497091,2.647547,0.299751,181.362189,23.181433,103.555434,26.957335,0.355081,0.048472,0.128365,0.024754,7.912292,0.231514,4.521532,0.459106,178.108281,15.218504,187.573071,22.788785,0.227676,0.05632,0.179241,0.027514,14.299141,0.653274,5.74392,0.724106,479.134282,27.029994,443.765247,56.188805,0.330557,0.056992,0.174416,0.101298,7.87972,148.663971,0.616877,10.433689,279.534917,0.304438
5,WFC1L57OY7,gbtree,0.4,0.01,0.8,6,8,reg:squarederror,0,0.75,128.852457,4.48254,6.641679,0.274204,1.902068,0.273806,94.517723,8.138756,52.820229,11.399821,0.673734,0.025379,0.082815,0.014188,6.58561,0.650551,3.907084,0.793151,119.674487,34.875506,125.134847,53.68617,0.515775,0.056879,0.1448,0.01813,10.084164,0.574321,3.596153,0.236421,223.374698,20.5892,159.871132,13.83609,0.676088,0.040836,0.082241,0.031188,9.090498,0.496375,2.583047,0.295293,181.771776,22.963477,101.914846,27.225507,0.350174,0.051087,0.131658,0.029234,7.903189,0.266383,4.516694,0.441777,178.483862,15.31232,188.698256,22.990408,0.226248,0.058746,0.179775,0.027372,14.285223,0.606498,5.706206,0.694397,474.895773,27.284464,442.082034,60.682194,0.335315,0.058059,0.17195,0.098401,7.770485,145.855636,0.621866,10.426303,278.383804,0.303912
4,SQ137IWT23,gbtree,0.4,0.1,0.8,6,8,reg:squarederror,0,0.5,105.037324,2.74321,6.836908,0.364652,2.071451,0.44464,100.583698,11.683508,59.92314,19.661538,0.65577,0.032064,0.099622,0.023449,6.479819,0.453406,3.7144,0.464092,112.177116,22.064625,114.086341,34.775933,0.519527,0.043509,0.144558,0.028825,9.97118,0.393512,3.570111,0.163951,221.025561,14.212194,164.128499,12.176379,0.682718,0.029614,0.077691,0.027732,9.110351,0.495971,2.61718,0.306478,181.719816,23.090115,102.551094,25.662146,0.349562,0.053085,0.136338,0.031301,7.882462,0.244557,4.483528,0.413694,176.240137,15.492662,184.317626,22.650116,0.229189,0.051201,0.183571,0.031421,14.300803,0.596519,5.739543,0.680234,477.979273,28.223068,444.394189,64.735382,0.328856,0.056647,0.178452,0.098256,7.762636,144.595458,0.619338,10.431205,278.646409,0.302536
17,9J5F9BD0D8,gbtree,0.3,0.05,1.2,10,8,reg:squarederror,0,0.75,222.865009,5.886806,4.799471,0.317601,1.806649,0.163612,54.031316,7.21501,38.507441,8.180487,0.816737,0.023736,0.094851,0.014953,5.175419,0.865341,3.173016,0.813581,74.850319,28.722612,78.907613,37.137734,0.690425,0.068098,0.134028,0.040303,6.356806,0.555927,2.168152,0.33729,93.522183,12.875038,57.475056,8.52658,0.852552,0.027043,0.055556,0.022326,8.882662,0.576821,2.445212,0.271917,177.586652,25.790957,94.95271,25.934651,0.355798,0.061111,0.159246,0.024003,7.832627,0.329817,4.338096,0.424096,176.82777,20.232387,183.677293,29.558129,0.200953,0.06109,0.200616,0.032392,13.993864,0.660465,5.73868,0.725697,488.027113,30.638904,463.16857,53.484364,0.326508,0.065353,0.192906,0.100099,5.443899,74.134606,0.786571,10.236385,280.813845,0.29442
6,X2M36O6DH3,gbtree,0.4,0.01,1.0,10,8,reg:squarederror,0,1.0,246.144911,9.487769,4.805576,0.420795,1.888233,0.55294,55.31228,10.555639,42.552025,19.419592,0.812687,0.028575,0.103424,0.019271,5.263627,0.999363,3.257158,0.992078,80.303304,37.496748,88.669983,50.904713,0.671332,0.077595,0.149706,0.024249,6.318958,0.756582,2.368404,0.647758,95.315985,19.246409,66.126261,18.80987,0.851391,0.037975,0.067848,0.040316,8.962973,0.601726,2.473893,0.337558,181.620747,27.603399,97.296354,27.294735,0.342515,0.066644,0.157687,0.024215,7.87894,0.338865,4.38586,0.408588,179.8684,19.895146,186.381375,28.62224,0.188431,0.072421,0.205924,0.030558,14.096669,0.650531,5.793156,0.747725,491.684721,35.222493,461.66695,54.803342,0.32056,0.055776,0.198558,0.094055,5.462721,76.97719,0.77847,10.312861,284.391289,0.283836
7,WM6A1S4U3W,gbtree,0.2,0.01,1.2,4,8,reg:squarederror,0,0.75,89.252449,0.586679,8.535054,0.163454,2.614943,0.174418,156.157955,6.681633,94.350302,10.233762,0.477604,0.009818,0.077367,0.01323,7.593009,0.138688,4.398026,0.144129,154.863842,6.030149,158.402068,9.308903,0.365257,0.006576,0.118008,0.009731,13.331772,0.188941,4.924855,0.074306,392.361869,10.26378,326.17704,12.938858,0.469968,0.01754,0.071766,0.015439,9.72579,0.588451,3.048632,0.514188,203.79382,23.473756,125.716649,35.29346,0.292105,0.041801,0.109736,0.030895,8.24372,0.272811,4.82511,0.396018,190.740192,16.817063,207.360925,21.85635,0.206255,0.040508,0.151919,0.020014,15.428264,0.950209,6.121568,0.722889,535.761671,36.073797,493.008236,74.399368,0.257141,0.067469,0.170609,0.103448,9.819945,234.461222,0.43761,11.132591,310.098561,0.251834
16,2OY5JYTSAA,gbtree,0.2,0.05,1.2,4,8,reg:squarederror,0,0.5,73.49493,0.686476,8.54379,0.176517,2.614839,0.196505,156.714192,8.570693,94.167861,12.373393,0.475837,0.014776,0.076227,0.013027,7.65149,0.2715,4.52607,0.382548,158.848581,16.397632,165.807,28.232596,0.362741,0.016076,0.115826,0.014334,13.346551,0.163964,4.935782,0.079257,394.345204,9.279819,329.984192,13.388057,0.468149,0.014711,0.072512,0.012871,9.732947,0.618317,3.060398,0.519671,204.484774,24.363311,127.068999,35.871054,0.291162,0.042252,0.111459,0.030928,8.217964,0.254311,4.800279,0.361329,189.872017,16.041574,205.919689,21.37759,0.20852,0.042578,0.153208,0.020353,15.455033,0.945722,6.131097,0.706949,538.299375,36.162961,497.161653,76.829372,0.255127,0.061373,0.16677,0.097774,9.847277,236.635992,0.435575,11.135315,310.885389,0.251603
13,2N9QBTGWJU,gbtree,0.3,0.05,1.0,15,8,reg:squarederror,0,0.5,275.384847,6.545323,3.459848,0.44545,1.832683,0.319483,31.310275,6.772241,29.51024,8.769019,0.892838,0.021645,0.088852,0.013496,4.090365,1.126235,2.654892,1.044323,53.360752,30.945827,61.122982,34.375581,0.782505,0.067404,0.138769,0.042244,3.870157,0.717757,1.562635,0.516094,36.872406,10.887439,25.149752,9.956462,0.932963,0.023515,0.045666,0.019317,9.056131,0.66009,2.390192,0.256132,185.101162,28.224753,94.992383,25.015662,0.322546,0.065652,0.171907,0.020217,8.040552,0.363688,4.384686,0.40042,186.05085,21.109903,191.967867,29.856292,0.146732,0.066032,0.221208,0.038781,14.302823,0.730138,5.801189,0.666819,516.291082,39.63336,474.386728,66.456475,0.28024,0.064659,0.213739,0.093874,3.80679,40.514478,0.869436,10.466502,295.814365,0.24984
8,UZN5JZ9SD9,gbtree,0.3,0.1,1.2,15,8,reg:squarederror,0,0.75,334.583869,8.885188,3.402077,0.534963,1.745105,0.19971,29.915026,7.326268,27.395353,7.789115,0.896996,0.024096,0.084987,0.013757,4.189589,1.052695,2.742377,0.917512,54.308792,26.914881,62.574729,29.690114,0.779416,0.06413,0.134298,0.037382,3.896318,0.988338,1.579005,0.619063,37.998527,16.268911,26.871014,14.804675,0.935482,0.026419,0.042828,0.01631,9.026674,0.666639,2.374402,0.240099,184.785738,27.34288,95.330261,23.865734,0.32396,0.066524,0.172123,0.016996,8.044724,0.358262,4.380703,0.393833,187.533423,22.543142,194.780314,32.699635,0.141087,0.069707,0.225941,0.042268,14.298267,0.702281,5.854042,0.664065,519.244815,34.957753,484.5292,64.404344,0.279349,0.060658,0.214174,0.094308,3.829328,40.740782,0.870631,10.456555,297.187992,0.248132
0,P63LB1GSD7,gbtree,0.2,0.01,1.0,15,8,reg:squarederror,0,1.0,407.360964,7.571134,3.929873,0.295803,1.649186,0.178539,35.166832,5.004749,28.175675,7.174984,0.880481,0.016705,0.08071,0.010682,4.359995,1.050321,2.769038,1.00063,58.059231,28.185341,65.648273,33.923707,0.773871,0.071547,0.130074,0.057365,5.539042,0.383757,2.128325,0.407765,62.455117,8.003072,39.499846,9.401047,0.903713,0.0146,0.038891,0.019795,9.135363,0.661232,2.56551,0.339938,189.357899,26.744436,103.001402,22.880423,0.314785,0.061079,0.165634,0.019312,8.045571,0.38839,4.379696,0.3817,188.01716,22.946639,195.167502,32.351094,0.141455,0.073916,0.222901,0.038939,14.634289,0.930613,6.008883,0.778121,535.661219,37.497986,492.182347,55.38876,0.257017,0.067955,0.216929,0.108998,4.609637,51.893726,0.852688,10.605075,304.345426,0.237752


In [7]:
# Take the first row of results_df and keep only its hyperparameter columns,
# dropping the 'param_' marker from their names
best_row = results_df.iloc[0]
best_params = {}
for column in results_df.columns:
    if 'param_' in column:
        best_params[column.replace('param_', '')] = best_row[column]
print('Best parameters: {}'.format(best_params))

Best parameters: {'booster': 'gbtree', 'eta': 0.3, 'gamma': 0.01, 'lambda': 1.0, 'max_depth': 6, 'nthread': 8, 'objective': 'reg:squarederror', 'seed': 0, 'subsample': 0.75}


## Best model

In [8]:
# Load data
def _load_array(file_name):
    """Read one .npy file from the 'data' folder of the DATA_ID results directory."""
    return np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', file_name))


X_train = _load_array('X_train_{}.npy'.format(DATA_ID))
X_test = _load_array('X_test_{}.npy'.format(DATA_ID))
Y_train = _load_array('Y_train_{}.npy'.format(DATA_ID))
Y_test = _load_array('Y_test_{}.npy'.format(DATA_ID))

In [9]:
# Fit one booster per column of Y_train with the selected hyperparameters,
# timing each fit, saving each model and collecting train/test scores per target
results = defaultdict(list)
tr_time = []
model_dir = os.path.join(RESULTS_PATH, DATA_ID, 'XGB_{}'.format(HS_DATE))
for target in range(Y_train.shape[1]):
    y_train_target = Y_train[:, target]
    y_test_target = Y_test[:, target]

    dtrain = xgb.DMatrix(data=X_train, label=y_train_target)
    dtest = xgb.DMatrix(data=X_test, label=y_test_target)

    # NOTE(review): early stopping monitors dtest, the same split that is scored
    # below, so the reported 'Test' metrics do not come from fully held-out data.
    callbacks = [xgb.callback.EarlyStopping(rounds=5, metric_name='rmse', maximize=False, save_best=True)]

    # NOTE(review): num_boost_round is not passed, so xgb.train falls back to its
    # default number of boosting rounds - confirm this is intended.
    # The evaluation set is registered under the name 'rmse'.
    t_start = time.time()
    model = xgb.train(best_params, dtrain, evals=[(dtest, 'rmse')], callbacks=callbacks, verbose_eval=False)
    tr_time.append(time.time() - t_start)

    # Save the model
    # NOTE(review): the file carries a '.joblib' extension but is written with
    # Booster.save_model, not with joblib.
    model_file = 'XGB_best_model_{}_{}_{}.joblib'.format(target, HS_DATE, DATA_ID)
    model.save_model(os.path.join(model_dir, model_file))

    train_preds = model.predict(dtrain)
    test_preds = model.predict(dtest)

    # (results key, metric function, ground truth, predictions), in reporting order
    scoring_plan = [
        ('Train_MAE', mean_absolute_error, y_train_target, train_preds),
        ('Train_MSE', mean_squared_error, y_train_target, train_preds),
        ('Train_R2', r2_score, y_train_target, train_preds),
        ('Test_MAE', mean_absolute_error, y_test_target, test_preds),
        ('Test_MSE', mean_squared_error, y_test_target, test_preds),
        ('Test_R2', r2_score, y_test_target, test_preds),
    ]
    for score_key, metric_fn, y_true, y_pred in scoring_plan:
        results[score_key].append(metric_fn(y_true, y_pred))

print('Training time: {:.4f}'.format(sum(tr_time)))

Training time: 320.8449


In [10]:
# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis.
# The indexing assumes the targets are laid out as consecutive (Fx, Fy, Fz)
# triplets, one triplet per force cell, so axis `offset` of cell `cell` sits at
# position 3 * cell + offset.
for subset in ['Train', 'Test']:
    for offset, force in enumerate(['Fx', 'Fy', 'Fz']):
        target_positions = [3 * cell + offset for cell in range(N_CELLS)]
        for loss in ['MAE', 'MSE', 'R2']:
            per_target = results[subset + '_' + loss]
            scores = [per_target[position] for position in target_positions]
            label = ' '.join([subset, force, loss])
            print(label + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

Train Fx MAE: 7.1869 ± 1.6493
Train Fx MSE: 107.9562 ± 46.7758
Train Fx R2: 0.6132 ± 0.0684
Train Fy MAE: 6.6962 ± 3.5113
Train Fy MSE: 110.3752 ± 95.1319
Train Fy R2: 0.4701 ± 0.1710
Train Fz MAE: 10.4374 ± 3.4805
Train Fz MSE: 239.5518 ± 179.0259
Train Fz R2: 0.6489 ± 0.0732
Test Fx MAE: 10.9776 ± 3.8296
Test Fx MSE: 260.9678 ± 172.2369
Test Fx R2: 0.2024 ± 0.1711
Test Fy MAE: 9.8136 ± 6.3831
Test Fy MSE: 322.7382 ± 396.1981
Test Fy R2: 0.0787 ± 0.1623
Test Fz MAE: 16.6790 ± 7.2368
Test Fz MSE: 666.6956 ± 668.7742
Test Fz R2: 0.1651 ± 0.1559


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()