In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
import time
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import xgboost as xgb

## Configuration

In [2]:
JOINT = 'Ankle'
FORCE_CELLS_PER_JOINT = {
    'Hip': [5, 6],
    'Knee': [3, 4, 7, 8],
    'Ankle': [1, 2]
}

CELLS = FORCE_CELLS_PER_JOINT[JOINT]

# Path where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the training and validation data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0013_09082021'
# Hyperparameters search date
HS_DATE = '11082021'
# Number of folds in cross-validation
CV = 4

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0013_09082021


## Hyperparameters seach analysis

In [3]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, '{}_XGB_{}'.format(JOINT, HS_DATE), '{}_XGB_{}_*.json'.format(JOINT, HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 486


In [4]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_booster,param_eta,param_gamma,param_lambda,param_max_depth,param_nthread,param_objective,param_seed,param_subsample,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std
0,J87OR9NMCB,gbtree,0.5,0.005,2,6,8,reg:squarederror,0,1.00,0.274100,0.014022,1.935000,0.380108,0.248563,0.241426,7.600142,3.847475,2.454426,2.782151,0.965182,0.014826,0.010303,0.011048,2.176885,0.399206,0.354464,0.248946,10.298374,4.413245,4.105782,3.843737,0.832137,0.051587,0.028025,0.022266,3.296834,0.293318,0.217370,0.118375,22.205767,6.246621,2.977228,1.713344,0.884052,0.027638,0.016273,0.008900,3.414697,0.141686,0.416298,0.218884,24.038563,1.777784,6.813392,4.088177,0.547373,0.111716,0.034247,0.031031
1,02E0FIH400,gbtree,0.4,0.005,2,7,8,reg:squarederror,0,0.75,0.294893,0.017705,1.784287,0.407626,0.266944,0.281199,6.820922,4.075547,2.566358,3.256424,0.968927,0.015976,0.010689,0.013166,1.896295,0.316412,0.233273,0.245348,7.871526,3.197680,2.705567,2.915558,0.869227,0.036026,0.021336,0.014575,3.275961,0.291628,0.166524,0.089558,21.775495,5.762477,2.091370,0.785968,0.886098,0.027242,0.012224,0.003951,3.286921,0.063103,0.471967,0.116575,23.334020,0.904992,8.000352,2.317055,0.565706,0.103315,0.011778,0.006232
2,09HPD3UF2Y,gbtree,0.5,0.010,1,5,8,reg:squarederror,0,0.50,0.167201,0.009177,2.381142,0.355890,0.246360,0.191129,11.085970,4.336854,2.846948,2.765539,0.948793,0.016234,0.011894,0.010748,2.666578,0.381613,0.255683,0.259450,14.410042,4.953912,4.226054,4.023038,0.753792,0.060181,0.027460,0.016337,3.371391,0.249725,0.292619,0.176773,21.675448,4.721145,3.583057,1.734353,0.883729,0.028832,0.025633,0.008972,3.485854,0.091624,0.445693,0.281053,24.654611,2.258763,7.665691,5.188128,0.541488,0.116704,0.045531,0.017918
3,09UWD00FXL,gbtree,0.3,0.050,0,5,8,reg:squarederror,0,1.00,0.237553,0.005409,2.355981,0.253968,0.248197,0.045046,10.142151,2.635516,2.275832,0.809438,0.952639,0.009649,0.009852,0.002479,2.449818,0.278350,0.286778,0.258361,12.176525,3.525806,4.204276,3.450737,0.795868,0.036029,0.022835,0.014600,3.360082,0.271281,0.247433,0.094116,21.281348,4.488204,2.884348,0.931000,0.885783,0.031807,0.016463,0.006540,3.408806,0.097913,0.457016,0.167995,23.299031,1.140611,7.503887,3.376179,0.566439,0.102428,0.015930,0.005532
4,0B26HU5OLB,gbtree,0.4,0.010,1,5,8,reg:squarederror,0,0.50,0.178279,0.008625,2.403248,0.386226,0.271298,0.163781,11.284529,4.617709,2.848917,2.671136,0.947817,0.017638,0.011921,0.010278,2.691386,0.462551,0.404626,0.316634,15.320716,6.294004,6.340039,5.257325,0.751604,0.072341,0.033791,0.026970,3.317987,0.215031,0.282287,0.056190,20.794666,3.862026,3.293423,0.659551,0.888017,0.029451,0.019603,0.004266,3.497327,0.149154,0.437717,0.210238,24.047186,1.765089,7.179553,3.691956,0.549998,0.105977,0.021657,0.018154
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
481,THLW20JLFS,gbtree,0.3,0.050,0,8,8,reg:squarederror,0,0.50,0.327310,0.021255,1.820770,0.525517,0.324124,0.319662,6.810414,4.577185,2.767866,3.421321,0.969169,0.018132,0.011622,0.013765,1.931559,0.477516,0.321785,0.312629,8.375477,4.917911,3.581263,4.147151,0.866027,0.058971,0.034237,0.017079,3.287577,0.272863,0.215407,0.114588,21.491795,4.849515,2.489132,1.512651,0.886317,0.028408,0.013125,0.007344,3.286517,0.141764,0.455104,0.130998,22.422067,1.763458,7.419437,2.740692,0.583303,0.091454,0.018539,0.010297
482,V1REW5G3PE,gbtree,0.5,0.050,1,7,8,reg:squarederror,0,0.75,0.267035,0.028548,1.758036,0.352838,0.320503,0.246683,6.526281,3.421436,2.634955,2.694962,0.970043,0.013374,0.011243,0.010714,2.041322,0.302800,0.273311,0.253744,8.879834,3.168179,3.000007,3.095985,0.850406,0.032652,0.027332,0.015306,3.253349,0.278459,0.166590,0.058543,21.891278,5.202392,2.195719,0.633468,0.884692,0.025497,0.015502,0.009071,3.326746,0.076549,0.462082,0.172336,23.575410,0.529270,7.898587,3.088415,0.562124,0.104348,0.013586,0.010392
483,VZQB390I7E,gbtree,0.5,0.050,1,5,8,reg:squarederror,0,0.50,0.174209,0.018972,2.381142,0.355890,0.246360,0.191129,11.085970,4.336854,2.846948,2.765539,0.948793,0.016234,0.011894,0.010748,2.666571,0.381617,0.255689,0.259444,14.409989,4.953944,4.226001,4.023083,0.753793,0.060181,0.027461,0.016338,3.371391,0.249725,0.292619,0.176773,21.675448,4.721145,3.583057,1.734353,0.883729,0.028832,0.025633,0.008972,3.485857,0.091629,0.445697,0.281058,24.654637,2.258800,7.665717,5.188168,0.541488,0.116704,0.045531,0.017918
484,X1QLH1R59Q,gbtree,0.3,0.050,1,3,8,reg:squarederror,0,0.50,0.122280,0.010580,3.305067,0.390091,0.311659,0.110298,20.307858,6.842219,4.623384,3.840024,0.905754,0.025273,0.019521,0.014464,3.392150,0.320176,0.549120,0.330305,23.301355,5.510039,9.178413,6.059815,0.615510,0.049441,0.032394,0.046678,3.840901,0.369874,0.371768,0.205486,26.687055,6.433363,3.840979,1.765833,0.859071,0.037626,0.018058,0.010948,3.760946,0.277601,0.625793,0.376683,27.212297,4.161432,10.427568,5.885328,0.518553,0.067122,0.031021,0.012142


In [5]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Valid']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy']]].mean(axis=1)

In [6]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Valid_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_booster,param_eta,param_gamma,param_lambda,param_max_depth,param_nthread,param_objective,param_seed,param_subsample,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
381,I604NHVSOG,gbtree,0.3,0.010,2,8,8,reg:squarederror,0,0.75,0.367147,0.010216,1.753175,0.332078,0.238874,0.212175,6.522944,3.092272,2.002674,2.511543,0.970055,0.011795,0.008280,0.010104,1.610473,0.303931,0.207624,0.175885,5.806591,2.448699,2.137215,1.711683,0.903901,0.031959,0.012527,0.005585,3.297708,0.317960,0.185314,0.081761,22.429417,6.208155,1.967976,1.193160,0.883937,0.025160,0.010686,0.004842,3.213125,0.091150,0.485465,0.137068,22.315447,0.902784,7.962823,3.156370,0.588788,0.100583,0.020016,0.010433,1.681824,6.164767,0.936978,3.255417,22.372432,0.736363
178,TMQBGWFH6C,gbtree,0.3,0.005,2,8,8,reg:squarederror,0,0.75,0.352341,0.011862,1.752739,0.331899,0.238781,0.212610,6.520035,3.092621,2.000764,2.513570,0.970070,0.011795,0.008270,0.010115,1.629819,0.286791,0.193695,0.188283,5.898313,2.368963,2.045481,1.795683,0.901523,0.030215,0.014905,0.003902,3.301017,0.319934,0.183181,0.086112,22.485735,6.268737,1.982505,1.184362,0.883719,0.025111,0.010835,0.004558,3.218280,0.092699,0.480400,0.128647,22.332971,0.926193,7.945409,3.130299,0.588363,0.100111,0.019592,0.009822,1.691279,6.209174,0.935797,3.259649,22.409353,0.736041
19,JR25RD2UQZ,gbtree,0.3,0.050,2,8,8,reg:squarederror,0,0.75,0.360766,0.018218,1.753237,0.331988,0.239002,0.212148,6.523037,3.092311,2.002767,2.511408,0.970055,0.011795,0.008281,0.010104,1.664006,0.300246,0.151800,0.155580,6.059049,2.434195,1.869491,1.689161,0.897615,0.030891,0.009106,0.004734,3.298242,0.318400,0.184901,0.082427,22.433097,6.211862,1.965324,1.198423,0.883923,0.025153,0.010699,0.004824,3.223262,0.088671,0.480015,0.144915,22.407359,0.953114,7.941579,3.236570,0.586800,0.100809,0.022051,0.010984,1.708621,6.291043,0.933835,3.260752,22.420228,0.735361
422,BY1YPWQC2K,gbtree,0.3,0.005,0,8,8,reg:squarederror,0,0.75,0.376870,0.024801,1.563606,0.285239,0.288064,0.163460,4.607945,1.975447,1.817790,1.384080,0.978756,0.007541,0.007928,0.005326,1.658962,0.430743,0.177791,0.173083,5.921927,3.322801,2.066148,1.998746,0.901907,0.045670,0.008328,0.003321,3.221743,0.266491,0.204980,0.101725,21.174285,5.019012,2.301934,1.420114,0.888767,0.028221,0.007101,0.006275,3.236597,0.145397,0.520073,0.160959,22.987785,2.050518,8.845000,3.202964,0.580940,0.094591,0.037300,0.016187,1.611284,5.264936,0.940332,3.229170,22.081035,0.734854
481,THLW20JLFS,gbtree,0.3,0.050,0,8,8,reg:squarederror,0,0.50,0.327310,0.021255,1.820770,0.525517,0.324124,0.319662,6.810414,4.577185,2.767866,3.421321,0.969169,0.018132,0.011622,0.013765,1.931559,0.477516,0.321785,0.312629,8.375477,4.917911,3.581263,4.147151,0.866027,0.058971,0.034237,0.017079,3.287577,0.272863,0.215407,0.114588,21.491795,4.849515,2.489132,1.512651,0.886317,0.028408,0.013125,0.007344,3.286517,0.141764,0.455104,0.130998,22.422067,1.763458,7.419437,2.740692,0.583303,0.091454,0.018539,0.010297,1.876165,7.592945,0.917598,3.287047,21.956931,0.734810
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
407,PU0D5UE8VA,gbtree,0.3,0.010,1,3,8,reg:squarederror,0,0.50,0.120450,0.005052,3.305067,0.390091,0.311659,0.110298,20.307858,6.842219,4.623384,3.840024,0.905754,0.025273,0.019521,0.014464,3.392150,0.320176,0.549120,0.330305,23.301355,5.510039,9.178413,6.059815,0.615510,0.049441,0.032394,0.046678,3.840901,0.369874,0.371768,0.205486,26.687055,6.433363,3.840979,1.765833,0.859071,0.037626,0.018058,0.010948,3.760946,0.277601,0.625793,0.376683,27.212297,4.161432,10.427568,5.885328,0.518553,0.067122,0.031021,0.012142,3.348609,21.804607,0.760632,3.800923,26.949676,0.688812
484,X1QLH1R59Q,gbtree,0.3,0.050,1,3,8,reg:squarederror,0,0.50,0.122280,0.010580,3.305067,0.390091,0.311659,0.110298,20.307858,6.842219,4.623384,3.840024,0.905754,0.025273,0.019521,0.014464,3.392150,0.320176,0.549120,0.330305,23.301355,5.510039,9.178413,6.059815,0.615510,0.049441,0.032394,0.046678,3.840901,0.369874,0.371768,0.205486,26.687055,6.433363,3.840979,1.765833,0.859071,0.037626,0.018058,0.010948,3.760946,0.277601,0.625793,0.376683,27.212297,4.161432,10.427568,5.885328,0.518553,0.067122,0.031021,0.012142,3.348609,21.804607,0.760632,3.800923,26.949676,0.688812
22,JZT07OTOL9,gbtree,0.3,0.010,2,3,8,reg:squarederror,0,0.50,0.117755,0.002122,3.334520,0.371739,0.306767,0.142722,20.800195,6.730383,4.624427,4.123360,0.903313,0.024994,0.019596,0.015782,3.384181,0.292160,0.568396,0.357472,23.186717,5.304911,9.301980,6.363616,0.618894,0.043071,0.036219,0.053647,3.864064,0.395355,0.341947,0.147624,28.313466,7.981361,4.119581,1.546728,0.853500,0.034187,0.016796,0.010128,3.762820,0.293683,0.622731,0.373497,27.044886,4.131728,10.202119,5.726336,0.520309,0.067329,0.028008,0.012232,3.359350,21.993456,0.761103,3.813442,27.679176,0.686904
288,90M82M5775,gbtree,0.3,0.050,2,3,8,reg:squarederror,0,0.50,0.125149,0.012759,3.334520,0.371739,0.306767,0.142722,20.800195,6.730383,4.624427,4.123360,0.903313,0.024994,0.019596,0.015782,3.384181,0.292160,0.568396,0.357472,23.186717,5.304911,9.301980,6.363616,0.618894,0.043071,0.036219,0.053647,3.864064,0.395355,0.341947,0.147624,28.313466,7.981361,4.119581,1.546728,0.853500,0.034187,0.016796,0.010128,3.762820,0.293683,0.622731,0.373497,27.044886,4.131728,10.202119,5.726336,0.520309,0.067329,0.028008,0.012232,3.359350,21.993456,0.761103,3.813442,27.679176,0.686904


In [7]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'booster': 'gbtree', 'eta': 0.3, 'gamma': 0.01, 'lambda': 2, 'max_depth': 8, 'nthread': 8, 'objective': 'reg:squarederror', 'seed': 0, 'subsample': 0.75}


## Best model

In [8]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_X_train_{}.npy'.format(JOINT, DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_X_test_{}.npy'.format(JOINT, DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_Y_train_{}.npy'.format(JOINT, DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_Y_test_{}.npy'.format(JOINT, DATA_ID)))

In [9]:
results = defaultdict(list)
tr_time = []
for target in range(Y_train.shape[1]):

    dtrain = xgb.DMatrix(data=X_train, label=Y_train[:, target])
    dtest = xgb.DMatrix(data=X_test, label=Y_test[:, target])

    callbacks = []#[xgb.callback.EarlyStopping(rounds=5, metric_name='rmse', maximize=False, save_best=True)]
    
    t_start = time.time()
    model = xgb.train(best_params, dtrain, evals=[(dtest, 'rmse')], callbacks=callbacks, verbose_eval=False)
    tr_time.append(time.time() - t_start)
    
    # Save the model
    model.save_model(os.path.join(RESULTS_PATH, DATA_ID, '{}_XGB_{}'.format(JOINT, HS_DATE), '{}_XGB_best_model_{}_{}_{}.joblib'.format(JOINT, target, HS_DATE, DATA_ID)))
    
    train_preds = model.predict(dtrain)
    test_preds = model.predict(dtest)

    results['Train_MAE'].append(mean_absolute_error(Y_train[:, target], train_preds))
    results['Train_MSE'].append(mean_squared_error(Y_train[:, target], train_preds))
    results['Train_R2'].append(r2_score(Y_train[:, target], train_preds))
    results['Test_MAE'].append(mean_absolute_error(Y_test[:, target], test_preds))
    results['Test_MSE'].append(mean_squared_error(Y_test[:, target], test_preds))
    results['Test_R2'].append(r2_score(Y_test[:, target], test_preds))

print('Training time: {:.4f}'.format(sum(tr_time)))

Training time: 0.4391


In [10]:
# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results['_'.join([subset, loss])][i + f] for i in range(0, len(CELLS) * 2, 2)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

Train Fx MAE: 1.6795 ± 0.1147
Train Fx MSE: 5.4172 ± 0.6823
Train Fx R2: 0.9746 ± 0.0033
Train Fy MAE: 1.4450 ± 0.1083
Train Fy MSE: 4.4168 ± 1.1980
Train Fy R2: 0.9231 ± 0.0050
Test Fx MAE: 3.8364 ± 0.3769
Test Fx MSE: 28.9411 ± 7.5544
Test Fx R2: 0.8268 ± 0.0286
Test Fy MAE: 3.0953 ± 0.0272
Test Fy MSE: 20.9479 ± 0.2038
Test Fy R2: 0.5000 ± 0.0444


In [11]:
str(model.get_score(importance_type='gain'))

"{'f14': 1150.5251614665997, 'f12': 777.8999008027935, 'f0': 291.4057120062687, 'f13': 292.075116744625, 'f2': 387.7356948588519, 'f1': 511.1852037187785, 'f3': 178.92452241502386, 'f11': 181.0763343104444, 'f10': 204.78689095260273, 'f9': 133.90674314658537, 'f5': 133.10687297824106, 'f4': 271.43591583834194, 'f7': 197.0478366264286, 'f8': 211.520428042831, 'f6': 87.50795867733665}"

In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()