In [27]:
import pandas as pd
import numpy as np
import os
import glob
import json
import time
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import xgboost as xgb

## Configuration

In [2]:
# Number of force cells in the robotic leg
N_CELLS = 8

# Path where the results are stored
RESULTS_PATH = '../../../results'
# ID of the training and test data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0004_15042021'
# Hyperparameters search date
HS_DATE = '15042021'
# Number of folds in cross-validation
CV = 6

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0004_15042021


## Hyperparameters seach analysis

In [18]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, 'XGB_{}'.format(HS_DATE), 'XGB_{}_*.json'.format(HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 486


In [19]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_booster,param_eta,param_gamma,param_lambda,param_max_depth,param_nthread,param_objective,param_seed,param_subsample,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Test_Fx_MAE_mean__mean,Test_Fx_MAE_mean__std,Test_Fx_MAE_std__mean,Test_Fx_MAE_std__std,Test_Fx_MSE_mean__mean,Test_Fx_MSE_mean__std,Test_Fx_MSE_std__mean,Test_Fx_MSE_std__std,Test_Fx_R2_mean__mean,Test_Fx_R2_mean__std,Test_Fx_R2_std__mean,Test_Fx_R2_std__std,Test_Fy_MAE_mean__mean,Test_Fy_MAE_mean__std,Test_Fy_MAE_std__mean,Test_Fy_MAE_std__std,Test_Fy_MSE_mean__mean,Test_Fy_MSE_mean__std,Test_Fy_MSE_std__mean,Test_Fy_MSE_std__std,Test_Fy_R2_mean__mean,Test_Fy_R2_mean__std,Test_Fy_R2_std__mean,Test_Fy_R2_std__std,Test_Fz_MAE_mean__mean,Test_Fz_MAE_mean__std,Test_Fz_MAE_std__mean,Test_Fz_MAE_std__std,Test_Fz_MSE_mean__mean,Test_Fz_MSE_mean__std,Test_Fz_MSE_std__mean,Test_Fz_MSE_std__std,Test_Fz_R2_mean__mean,Test_Fz_R2_mean__std,Test_Fz_R2_std__mean,Test_Fz_R2_std__std
0,PTEHVO2QK2,gbtree,0.3,0.10,1.2,12,8,reg:squarederror,0,0.50,20.806981,1.300716,5.828721,2.136396,3.286920,1.993901,85.646589,53.755825,87.850878,79.804586,0.761662,0.124334,0.194398,0.117103,7.352057,1.538535,6.037452,1.380673,184.165590,69.376084,248.667178,95.489092,0.621482,0.130884,0.181017,0.047947,6.712153,3.118217,3.379021,2.111139,102.311511,86.018884,90.420212,94.483530,0.798421,0.156355,0.192776,0.178722,12.126970,1.637808,4.012353,1.039281,310.874458,82.935998,209.782153,98.922784,-0.072073,0.380281,0.312984,0.223286,11.403196,1.550200,8.376759,1.883289,425.733335,111.011080,630.313615,171.386303,-0.073392,0.116810,0.261770,0.101672,16.126198,2.805009,6.517967,1.687439,621.968472,185.555494,581.575327,281.867730,-0.322329,0.547673,0.516527,0.196665
1,GOSEIOZHXE,gbtree,0.2,0.01,1.2,12,8,reg:squarederror,0,0.50,21.770642,1.148223,6.401059,1.743584,3.026555,1.196900,88.854627,43.929458,75.520205,45.144954,0.750736,0.112639,0.168578,0.097841,7.423994,1.595177,6.057679,1.760923,189.963198,88.618753,263.647299,138.762248,0.606946,0.142093,0.216146,0.055919,8.358155,2.304446,3.608157,1.885310,130.517789,72.982303,100.868721,92.081807,0.752268,0.150327,0.206507,0.198816,12.235199,1.485155,4.281006,0.938203,312.228844,74.970537,218.321190,88.823604,-0.072885,0.366872,0.325108,0.220172,11.250470,1.586538,8.324505,1.924200,417.348033,110.528910,628.799949,167.527628,-0.038149,0.107098,0.231810,0.073191,16.584548,2.835792,6.845719,1.453005,648.212614,185.035502,618.586435,307.224457,-0.371530,0.541521,0.559801,0.236390
2,BU8XOLH8WK,gbtree,0.4,0.01,0.8,6,8,reg:squarederror,0,0.50,9.952336,0.566039,7.723353,1.604774,3.391554,1.965806,131.527493,55.185971,111.501712,92.260395,0.647833,0.112687,0.184898,0.095764,8.207370,0.765423,6.064603,0.719504,208.226473,34.601198,261.998117,46.700908,0.539282,0.079351,0.167955,0.030269,9.089048,2.050127,3.987035,2.084723,178.702493,79.996708,134.036324,104.320017,0.666784,0.148610,0.249400,0.192878,11.967705,1.714449,3.928991,1.161412,297.127454,79.650819,194.414553,86.228322,-0.026313,0.351857,0.316866,0.193391,11.203747,1.631814,8.276788,1.677309,413.655956,110.204796,608.670352,163.582742,-0.046018,0.170482,0.311183,0.292065,15.970900,3.174752,6.421824,1.778102,610.575589,217.775354,582.730912,293.139899,-0.273830,0.539141,0.454634,0.202566
3,1ADQP2NY15,gbtree,0.3,0.05,1.2,10,8,reg:squarederror,0,0.50,17.527234,0.945707,6.242355,1.948542,3.266454,1.819853,93.502180,53.629476,90.078179,79.729932,0.744939,0.115966,0.193205,0.105945,7.415097,1.052784,6.146893,1.005797,183.627867,53.223411,257.337945,77.580384,0.622136,0.093760,0.210606,0.021614,7.012989,2.987668,3.312925,2.091790,111.244453,84.846076,92.351347,97.216478,0.779193,0.156117,0.196047,0.185016,12.007544,1.623943,4.078716,1.118971,303.484047,81.421972,204.349860,94.829554,-0.045869,0.370028,0.309734,0.216639,11.279912,1.595664,8.293362,1.891606,417.477715,110.406620,615.476615,169.986781,-0.059164,0.113205,0.263232,0.109414,16.020148,2.998667,6.542361,1.716901,609.067375,183.693405,566.598407,265.696503,-0.286058,0.520657,0.475474,0.189964
4,7O6RPGDDBF,gbtree,0.2,0.05,1.0,6,8,reg:squarederror,0,0.50,11.046924,0.463690,8.073818,1.459242,3.337297,1.570721,138.296385,51.334992,109.089655,72.897853,0.626361,0.115636,0.167106,0.097586,8.467709,1.141520,6.451255,1.421944,228.828408,78.050407,302.523016,129.805178,0.527566,0.097576,0.178297,0.025420,10.780540,1.757144,4.405462,1.989135,231.133547,76.238212,169.990651,110.107562,0.594216,0.140503,0.246956,0.208427,12.115818,1.500073,4.442295,0.927124,303.099529,66.243335,215.890372,68.320494,-0.029318,0.332895,0.329632,0.184366,10.956519,1.677706,8.150386,1.846448,398.591054,106.762539,604.206735,158.227219,0.024114,0.102701,0.203632,0.080580,16.739249,3.180753,7.132251,1.833448,658.863823,225.910089,662.079525,346.729069,-0.343928,0.546775,0.529955,0.273028
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
481,Y6UL7OBUTN,gbtree,0.4,0.05,1.0,10,8,reg:squarederror,0,0.50,16.260606,1.137723,6.221831,2.115246,3.224871,1.939068,93.533124,57.444352,89.968089,86.034613,0.749880,0.116931,0.166193,0.081273,7.156961,0.776598,5.638245,0.994734,161.665493,37.026585,212.798024,57.132748,0.629882,0.075585,0.187533,0.027033,6.959745,3.133339,3.586521,2.107294,113.206072,87.823144,97.866344,98.487188,0.775326,0.171450,0.198710,0.186486,12.102256,1.639959,4.002641,1.094372,310.521772,82.148726,209.311911,99.064196,-0.075249,0.383243,0.325106,0.230087,11.500233,1.611367,8.466104,1.926509,433.645089,113.223789,636.893168,176.461304,-0.098571,0.144419,0.315540,0.202221,16.003299,2.977157,6.387011,1.677986,613.975688,184.777018,564.664817,267.359801,-0.316299,0.552153,0.497891,0.199236
482,5EUU7AY59P,gbtree,0.2,0.10,1.0,10,8,reg:squarederror,0,0.75,22.818095,1.072744,6.627446,1.766148,3.137652,1.624666,97.016690,50.427875,86.530552,71.318450,0.738127,0.118094,0.158699,0.105932,7.637002,1.392165,6.152426,1.590711,196.863659,78.357723,276.593704,130.204602,0.595609,0.127920,0.208561,0.042478,8.810085,2.288327,3.908270,1.816126,147.685830,75.354616,118.444710,89.809683,0.732735,0.148333,0.206003,0.198567,12.116313,1.463629,4.326859,1.008106,308.762003,75.496795,218.927818,91.028054,-0.058359,0.361232,0.331607,0.208197,11.174996,1.644167,8.344849,1.974229,414.231507,109.687420,625.076807,160.544046,-0.019719,0.103972,0.221069,0.079828,16.451323,2.790860,6.817917,1.507397,639.411742,184.418234,614.610324,302.671471,-0.338299,0.513553,0.525617,0.227222
483,KJS0JVJA94,gbtree,0.3,0.05,0.8,6,8,reg:squarederror,0,0.75,13.045944,0.716725,7.516665,1.404260,2.992380,1.210878,121.422340,42.566940,86.531313,46.458095,0.654035,0.120932,0.184306,0.104216,8.424110,1.054124,6.255297,1.025339,220.680528,56.798176,274.753100,88.309859,0.528183,0.110432,0.151582,0.028877,9.315718,2.158998,3.938288,1.731642,181.820691,76.384524,137.228579,96.142147,0.658186,0.156791,0.243360,0.182176,11.833816,1.616283,4.079092,0.943672,293.235080,74.687962,197.490821,74.834495,-0.006138,0.348754,0.318071,0.183108,11.028110,1.597278,8.172068,1.776510,402.519140,105.945837,601.450002,161.392748,-0.007113,0.125989,0.239990,0.147813,16.085773,3.127274,6.810829,1.792903,622.526026,222.939653,620.624274,336.155643,-0.254270,0.493377,0.442351,0.185440
484,7FQUUUNV09,gbtree,0.4,0.10,0.8,12,8,reg:squarederror,0,0.50,19.365904,1.281481,6.115299,2.216370,3.238999,1.947204,89.731155,58.575944,86.115971,83.123323,0.757168,0.120231,0.167050,0.085494,7.002431,1.005637,5.528305,1.106421,157.875977,42.543086,200.176098,64.471690,0.663257,0.085221,0.157370,0.031801,6.667847,3.131423,3.440370,2.073518,101.617036,82.970335,90.289007,95.667402,0.796964,0.157719,0.184671,0.182608,12.169673,1.598980,3.970829,1.034071,310.867317,82.468457,204.298348,94.761449,-0.085984,0.385311,0.326014,0.220836,11.576177,1.535544,8.479662,1.835158,435.635059,111.606921,632.921508,174.367498,-0.126443,0.146119,0.332786,0.204321,16.068501,2.935469,6.529112,1.705351,623.686119,179.836506,587.924929,266.974839,-0.326785,0.537530,0.500947,0.186980


In [20]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Test']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy', 'Fz']]].mean(axis=1)

In [21]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Test_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_booster,param_eta,param_gamma,param_lambda,param_max_depth,param_nthread,param_objective,param_seed,param_subsample,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Test_Fx_MAE_mean__mean,Test_Fx_MAE_mean__std,Test_Fx_MAE_std__mean,Test_Fx_MAE_std__std,Test_Fx_MSE_mean__mean,Test_Fx_MSE_mean__std,Test_Fx_MSE_std__mean,Test_Fx_MSE_std__std,Test_Fx_R2_mean__mean,Test_Fx_R2_mean__std,Test_Fx_R2_std__mean,Test_Fx_R2_std__std,Test_Fy_MAE_mean__mean,Test_Fy_MAE_mean__std,Test_Fy_MAE_std__mean,Test_Fy_MAE_std__std,Test_Fy_MSE_mean__mean,Test_Fy_MSE_mean__std,Test_Fy_MSE_std__mean,Test_Fy_MSE_std__std,Test_Fy_R2_mean__mean,Test_Fy_R2_mean__std,Test_Fy_R2_std__mean,Test_Fy_R2_std__std,Test_Fz_MAE_mean__mean,Test_Fz_MAE_mean__std,Test_Fz_MAE_std__mean,Test_Fz_MAE_std__std,Test_Fz_MSE_mean__mean,Test_Fz_MSE_mean__std,Test_Fz_MSE_std__mean,Test_Fz_MSE_std__std,Test_Fz_R2_mean__mean,Test_Fz_R2_mean__std,Test_Fz_R2_std__mean,Test_Fz_R2_std__std,Train_MAE,Train_MSE,Train_R2,Test_MAE,Test_MSE,Test_R2
134,5X14K8IO8U,gbtree,0.3,0.05,1.2,4,8,reg:squarederror,0,1.0,10.443708,0.571809,9.182725,1.250873,3.513662,1.619213,175.836141,52.429799,132.692157,87.041316,0.518545,0.125096,0.189093,0.111463,8.866087,0.806719,6.398834,0.909766,240.204087,53.988585,300.534713,90.418312,0.459523,0.094961,0.143273,0.036745,11.553430,2.043508,4.273249,1.843870,279.158443,103.332326,198.267526,131.422446,0.521865,0.175526,0.230697,0.175338,12.050853,1.646980,4.213353,1.134222,298.277726,72.609813,206.296498,72.757799,-0.017427,0.342344,0.328807,0.179398,10.840789,1.614398,8.160346,1.685704,391.131379,107.284721,594.694869,161.721196,0.054778,0.133726,0.203987,0.109812,15.902038,3.195682,6.643298,1.860052,593.940002,220.178522,579.110371,291.846989,-0.207729,0.493957,0.425936,0.198426,9.867414,231.732890,0.499978,12.931227,427.783036,-0.056792
341,NDOLFR1ZJT,gbtree,0.3,0.10,1.2,4,8,reg:squarederror,0,1.0,10.403922,0.562595,9.182725,1.250873,3.513662,1.619213,175.836141,52.429799,132.692157,87.041316,0.518545,0.125096,0.189093,0.111463,8.866087,0.806719,6.398834,0.909766,240.204087,53.988585,300.534713,90.418312,0.459523,0.094961,0.143273,0.036745,11.553430,2.043508,4.273249,1.843870,279.158443,103.332326,198.267526,131.422446,0.521865,0.175526,0.230697,0.175338,12.050853,1.646980,4.213353,1.134222,298.277726,72.609813,206.296498,72.757799,-0.017427,0.342344,0.328807,0.179398,10.840789,1.614398,8.160346,1.685704,391.131379,107.284721,594.694869,161.721196,0.054778,0.133726,0.203987,0.109812,15.902038,3.195682,6.643298,1.860052,593.940002,220.178522,579.110371,291.846989,-0.207729,0.493957,0.425936,0.198426,9.867414,231.732890,0.499978,12.931227,427.783036,-0.056792
152,FVBCE8WIBL,gbtree,0.3,0.01,1.2,4,8,reg:squarederror,0,1.0,10.428183,0.560162,9.182725,1.250873,3.513662,1.619213,175.836141,52.429799,132.692157,87.041316,0.518545,0.125096,0.189093,0.111463,8.866087,0.806719,6.398834,0.909766,240.204087,53.988585,300.534713,90.418312,0.459523,0.094961,0.143273,0.036745,11.553430,2.043508,4.273249,1.843870,279.158443,103.332326,198.267526,131.422446,0.521865,0.175526,0.230697,0.175338,12.050853,1.646980,4.213353,1.134222,298.277726,72.609813,206.296498,72.757799,-0.017427,0.342344,0.328807,0.179398,10.840789,1.614398,8.160346,1.685704,391.131379,107.284721,594.694869,161.721196,0.054778,0.133726,0.203987,0.109812,15.902038,3.195682,6.643298,1.860052,593.940002,220.178522,579.110371,291.846989,-0.207729,0.493957,0.425936,0.198426,9.867414,231.732890,0.499978,12.931227,427.783036,-0.056792
350,RNR7IFGWXB,gbtree,0.3,0.01,0.8,4,8,reg:squarederror,0,1.0,10.436562,0.565319,9.149469,1.198446,3.379535,1.657418,173.526864,51.699550,124.814084,89.832841,0.519477,0.121280,0.190747,0.109330,8.776314,0.838054,6.327865,0.901184,232.819290,57.824927,292.045713,95.504876,0.467422,0.101924,0.141492,0.036820,11.701177,2.152776,4.625909,2.312728,291.843914,115.712907,232.184958,184.136517,0.517608,0.175443,0.236932,0.171125,11.989807,1.601269,4.095961,0.984322,296.671214,71.773495,201.428004,68.587105,-0.017585,0.341727,0.329873,0.179942,10.816380,1.663747,8.096410,1.765174,390.013930,109.110688,591.633748,165.971690,0.053571,0.130922,0.205355,0.109542,15.930527,3.203751,6.642326,1.895106,593.930468,220.388708,577.092352,292.984185,-0.209976,0.493622,0.427296,0.199271,9.875653,232.730023,0.501503,12.912238,426.871870,-0.057997
445,0MA1JITCX2,gbtree,0.3,0.05,0.8,4,8,reg:squarederror,0,1.0,10.454099,0.561483,9.149469,1.198446,3.379535,1.657418,173.526864,51.699550,124.814084,89.832841,0.519477,0.121280,0.190747,0.109330,8.776314,0.838054,6.327865,0.901184,232.819290,57.824927,292.045713,95.504876,0.467422,0.101924,0.141492,0.036820,11.701177,2.152776,4.625909,2.312728,291.843914,115.712907,232.184958,184.136517,0.517608,0.175443,0.236932,0.171125,11.989807,1.601269,4.095961,0.984322,296.671214,71.773495,201.428004,68.587105,-0.017585,0.341727,0.329873,0.179942,10.816380,1.663747,8.096410,1.765174,390.013930,109.110688,591.633748,165.971690,0.053571,0.130922,0.205355,0.109542,15.930527,3.203751,6.642326,1.895106,593.930468,220.388708,577.092352,292.984185,-0.209976,0.493622,0.427296,0.199271,9.875653,232.730023,0.501503,12.912238,426.871870,-0.057997
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
398,G9K3AXCCS1,gbtree,0.4,0.01,1.0,30,8,reg:squarederror,0,1.0,55.629918,5.591363,5.378578,2.456097,3.665838,1.993798,77.132672,57.694155,89.272788,81.800539,0.800723,0.121645,0.164707,0.091997,6.632500,1.125059,5.569060,0.904298,154.268427,38.831498,212.452155,53.370244,0.682853,0.094205,0.133409,0.046351,5.900755,3.665015,4.004751,1.961302,87.134252,85.243691,91.548875,92.776228,0.816332,0.167661,0.185321,0.175512,12.468410,1.694300,4.228995,1.209090,327.482082,92.902786,219.799998,111.612383,-0.135436,0.429110,0.319415,0.226512,11.947146,1.541870,8.691975,1.926145,453.735387,113.245615,655.932218,181.466689,-0.193801,0.162702,0.380160,0.215855,16.533367,2.775588,6.494799,1.705022,653.882062,181.159724,605.986814,289.947220,-0.427178,0.572318,0.581577,0.234995,5.970611,106.178450,0.766636,13.649641,478.366510,-0.252138
220,OH15MGUTO6,gbtree,0.4,0.05,1.0,30,8,reg:squarederror,0,1.0,56.179735,5.505585,5.368158,2.453372,3.674348,2.004711,77.080902,57.688848,89.315950,81.804617,0.800845,0.121570,0.164837,0.092119,6.594474,1.081602,5.609879,0.915855,154.076945,38.622984,212.604156,53.450422,0.687544,0.090793,0.145012,0.035721,5.918274,3.640722,4.008709,1.962345,87.235366,85.131822,91.554755,92.786009,0.816010,0.167346,0.185240,0.175583,12.472350,1.692908,4.228281,1.208439,327.551549,92.852854,219.922621,111.515044,-0.135579,0.429139,0.319239,0.226404,11.949250,1.541745,8.692653,1.925190,453.838691,113.177888,656.126902,181.356744,-0.193962,0.162865,0.380199,0.216002,16.530540,2.775659,6.493029,1.718795,653.778944,181.187538,606.074882,290.642063,-0.427893,0.572354,0.581638,0.232833,5.960302,106.131071,0.768133,13.650713,478.389728,-0.252478
354,6IG4P5Q1CM,gbtree,0.4,0.10,0.8,30,8,reg:squarederror,0,1.0,56.882224,5.590676,5.127607,2.051106,2.961327,1.111086,62.981609,38.335155,54.700087,30.206285,0.816805,0.097766,0.145160,0.073576,6.526676,1.136151,5.602231,0.912087,152.098373,39.751435,211.855955,53.046211,0.690688,0.091159,0.145302,0.035868,5.771632,3.741366,3.965306,1.947850,85.313582,90.460359,86.424205,91.661139,0.822891,0.165219,0.184332,0.174953,12.473902,1.766873,4.160125,1.303519,326.868731,93.748033,216.097516,111.440201,-0.137898,0.435862,0.320769,0.225191,11.957887,1.547079,8.682207,1.917307,454.652518,114.141751,654.687072,182.187478,-0.200886,0.159835,0.394174,0.206004,16.586760,2.737333,6.523918,1.745968,652.645265,183.684729,598.879253,285.208231,-0.426859,0.555490,0.575026,0.207287,5.808638,100.131188,0.776795,13.672850,478.055505,-0.255214
352,JDYRSE9R04,gbtree,0.4,0.05,0.8,30,8,reg:squarederror,0,1.0,56.400979,5.424907,5.166422,2.000530,2.970714,1.100079,63.490915,37.707142,55.177829,29.478428,0.815519,0.096025,0.145479,0.073049,6.558581,1.090117,5.589913,0.913838,152.408979,39.350264,211.707276,53.154890,0.684123,0.085500,0.153113,0.040670,5.796873,3.726207,3.957739,1.959388,85.535566,90.316970,86.315340,91.729977,0.821773,0.164487,0.184222,0.175036,12.475075,1.763191,4.167324,1.299125,326.967023,93.628940,216.179237,111.461321,-0.137970,0.435584,0.320527,0.225316,11.956136,1.544242,8.682557,1.921844,454.614788,114.157432,654.560228,182.341094,-0.200947,0.159902,0.394000,0.206123,16.588474,2.743659,6.532037,1.747441,654.192726,185.846791,602.337658,290.688727,-0.428449,0.558721,0.577184,0.211079,5.840626,100.478486,0.773805,13.673229,478.591512,-0.255788


In [22]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'booster': 'gbtree', 'eta': 0.3, 'gamma': 0.05, 'lambda': 1.2, 'max_depth': 4, 'nthread': 8, 'objective': 'reg:squarederror', 'seed': 0, 'subsample': 1.0}


## Best model

In [23]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_train_{}.npy'.format(DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_test_{}.npy'.format(DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_train_{}.npy'.format(DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_test_{}.npy'.format(DATA_ID)))

In [29]:
results = defaultdict(list)
tr_time = []
for target in range(Y_train.shape[1]):

    dtrain = xgb.DMatrix(data=X_train, label=Y_train[:, target])
    dtest = xgb.DMatrix(data=X_test, label=Y_test[:, target])

    callbacks = [xgb.callback.EarlyStopping(rounds=5, metric_name='rmse', maximize=False, save_best=True)]
    
    t_start = time.time()
    model = xgb.train(best_params, dtrain, evals=[(dtest, 'rmse')], callbacks=callbacks, verbose_eval=False)
    tr_time.append(time.time() - t_start)
    
    # Save the model
    model.save_model(os.path.join(RESULTS_PATH, DATA_ID, 'XGB_{}'.format(HS_DATE), 'XGB_best_model_{}_{}_{}.joblib'.format(target, HS_DATE, DATA_ID)))
    
    train_preds = model.predict(dtrain)
    test_preds = model.predict(dtest)

    results['Train_MAE'].append(mean_absolute_error(Y_train[:, target], train_preds))
    results['Train_MSE'].append(mean_squared_error(Y_train[:, target], train_preds))
    results['Train_R2'].append(r2_score(Y_train[:, target], train_preds))
    results['Test_MAE'].append(mean_absolute_error(Y_test[:, target], test_preds))
    results['Test_MSE'].append(mean_squared_error(Y_test[:, target], test_preds))
    results['Test_R2'].append(r2_score(Y_test[:, target], test_preds))

print('Training time: {:.4f}'.format(sum(tr_time)))

Training time: 13.5001


In [30]:
# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy', 'Fz']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results['_'.join([subset, loss])][i + f] for i in range(0, N_CELLS * 3, 3)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

Train Fx MAE: 8.3961 ± 3.1096
Train Fx MSE: 141.7054 ± 99.0585
Train Fx R2: 0.5848 ± 0.2275
Train Fy MAE: 7.6546 ± 5.2656
Train Fy MSE: 162.0622 ± 172.9832
Train Fy R2: 0.5838 ± 0.1167
Train Fz MAE: 11.3034 ± 5.3325
Train Fz MSE: 275.4277 ± 219.0808
Train Fz R2: 0.5517 ± 0.3086
Test Fx MAE: 12.0839 ± 4.9598
Test Fx MSE: 289.8891 ± 222.9538
Test Fx R2: 0.4305 ± 0.1368
Test Fy MAE: 10.4137 ± 7.5679
Test Fy MSE: 357.7774 ± 463.0753
Test Fy R2: 0.3818 ± 0.1306
Test Fz MAE: 17.6703 ± 6.5851
Test Fz MSE: 668.5766 ± 703.9588
Test Fz R2: 0.4353 ± 0.1677


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()