In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor

## Configuration

In [2]:
JOINT = 'Hip'
FORCE_CELLS_PER_JOINT = {
    'Hip': [5, 6],
    'Knee': [3, 4, 7, 8],
    'Ankle': [1, 2]
}

CELLS = FORCE_CELLS_PER_JOINT[JOINT]

# Path where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the training and validation data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0013_09082021'
# Hyperparameters search date
HS_DATE = '13082021'
# Number of folds in cross-validation
CV = 4

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0013_09082021


## Hyperparameters seach analysis

In [3]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, '{}_RF_{}'.format(JOINT, HS_DATE), '{}_RF_{}_*.json'.format(JOINT, HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 675


In [4]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std
0,QDZAX41D9Q,20,0.3,0.00100,0.00100,1000,2.706464,0.022511,3.149630,0.158886,0.278632,0.050678,19.163644,1.950161,2.190036,0.843340,0.915353,0.007399,0.032643,0.003907,4.360570,0.326650,2.546687,0.236264,68.791261,12.436380,61.057767,11.885777,0.829841,0.016720,0.035372,0.003967,5.982367,0.654415,0.651217,0.172090,71.587360,16.348910,14.474156,7.247331,0.668853,0.056490,0.155642,0.049055,9.722560,1.376287,6.682795,0.946424,391.585124,151.078388,371.627425,145.993267,0.291945,0.177418,0.181136,0.141904
1,N33RIL290B,30,0.7,0.00001,0.00100,2500,14.037893,0.180055,1.794693,0.092213,0.158235,0.030370,6.401733,0.692608,0.812048,0.280144,0.971604,0.002906,0.011263,0.001340,2.459779,0.188945,1.372280,0.134768,21.863629,4.278323,19.140669,4.035088,0.942477,0.007286,0.014878,0.001519,5.800237,0.768904,0.708694,0.244829,69.371692,19.267711,16.354897,9.550552,0.680535,0.067032,0.157816,0.049350,9.696751,1.534430,6.706909,1.035368,403.842293,173.030758,383.769955,167.190197,0.280752,0.200800,0.190760,0.150737
2,4YEIECWXUS,50,0.5,0.00001,0.00100,1000,4.236644,0.025225,1.822474,0.095098,0.161361,0.029502,6.571247,0.714005,0.837884,0.277078,0.970850,0.002933,0.011577,0.001337,2.499177,0.193752,1.399101,0.137215,22.509393,4.460199,19.725679,4.208216,0.941046,0.007371,0.015046,0.001447,5.821770,0.759453,0.705991,0.241566,69.704169,18.966877,16.283949,9.318270,0.678634,0.067637,0.158377,0.050923,9.715673,1.537423,6.735086,1.051802,405.388563,170.274906,385.688536,164.536951,0.281501,0.205232,0.198803,0.157655
3,NYAU40XASJ,20,0.5,0.00010,0.00100,2500,10.424779,0.068942,1.930844,0.110782,0.186985,0.033909,7.262452,0.822081,1.076481,0.327577,0.967614,0.003294,0.013430,0.001511,2.652606,0.204807,1.504714,0.137251,24.603940,4.646664,21.617870,4.318806,0.936283,0.008280,0.015654,0.001851,5.826723,0.758486,0.707745,0.246295,69.648296,18.938781,16.177002,9.255142,0.679021,0.067257,0.157766,0.050416,9.709117,1.518099,6.732614,1.024832,403.088395,168.883908,383.379843,163.095701,0.284128,0.203210,0.196094,0.155553
4,D6HOLH8Y1Z,50,0.7,0.00100,0.00001,2500,12.021019,0.153341,2.927510,0.143552,0.251945,0.049630,16.798605,1.630120,1.763493,0.735418,0.925895,0.007158,0.028032,0.003730,3.983625,0.287914,2.261305,0.206219,57.661630,9.842267,50.552989,9.364274,0.849031,0.016914,0.038061,0.004209,5.943129,0.695320,0.672253,0.184021,71.484063,17.083002,14.895504,7.686860,0.670159,0.055793,0.156435,0.048886,9.676396,1.404024,6.618640,0.962386,389.845389,154.196817,369.354249,148.963969,0.289877,0.178181,0.172197,0.142755
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
670,DDHUQ6S6M2,13,0.5,0.00100,0.00010,2500,8.728270,0.081964,3.337589,0.158018,0.330825,0.062143,21.087042,2.015397,2.698869,0.990023,0.906436,0.008486,0.037196,0.004455,4.723689,0.326201,2.823828,0.232703,76.561936,12.058423,68.337752,11.462960,0.815161,0.019336,0.033882,0.004976,6.039752,0.654253,0.683157,0.184683,72.928149,16.395609,14.894585,7.497414,0.662573,0.055565,0.159037,0.049191,9.706589,1.348204,6.650091,0.905593,390.162216,148.869727,369.895542,143.638082,0.289162,0.181216,0.177580,0.146078
671,PT1OC6BJ0U,20,0.1,0.00100,0.00010,1000,1.667082,0.007134,3.781660,0.176350,0.357567,0.046614,27.187825,2.578547,3.557255,1.087754,0.879322,0.009443,0.048292,0.005693,5.440040,0.413635,3.320952,0.306740,105.171512,20.086194,94.776385,19.329342,0.758711,0.019319,0.034629,0.006528,6.141316,0.667687,0.665180,0.172807,73.099631,16.519164,14.376106,7.628662,0.662645,0.050584,0.156083,0.041277,9.857425,1.503232,6.762318,1.059599,384.386223,163.662842,363.588976,158.903364,0.302260,0.157655,0.148570,0.113274
672,R8XMNSLVV9,50,0.7,0.00100,0.00010,2500,11.994702,0.158911,2.927510,0.143552,0.251945,0.049630,16.798605,1.630120,1.763493,0.735418,0.925895,0.007158,0.028032,0.003730,3.983625,0.287914,2.261305,0.206219,57.661630,9.842267,50.552989,9.364274,0.849031,0.016914,0.038061,0.004209,5.943129,0.695320,0.672253,0.184021,71.484063,17.083002,14.895504,7.686860,0.670159,0.055793,0.156435,0.048886,9.676396,1.404024,6.618640,0.962386,389.845389,154.196817,369.354249,148.963969,0.289877,0.178181,0.172197,0.142755
673,L95NVFEXNG,30,0.2,0.00010,0.00001,2500,7.324383,0.072536,1.315266,0.075283,0.128726,0.023352,3.571437,0.412740,0.498296,0.171182,0.984126,0.001548,0.006457,0.000749,1.974686,0.168096,1.245569,0.125438,15.239174,3.198049,13.941751,3.071052,0.967659,0.004113,0.002142,0.001050,5.783319,0.738542,0.666545,0.228127,68.845856,18.810254,15.687362,8.852924,0.682462,0.068324,0.155055,0.049811,9.691196,1.558196,6.754384,1.079241,403.690140,173.809846,384.556161,168.182984,0.293363,0.202603,0.199903,0.146799


In [5]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Valid']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy']]].mean(axis=1)

In [6]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Valid_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
405,LDY9X7IJP7,30,0.1,0.00001,0.00010,2500,5.062267,0.008730,1.424746,0.077926,0.138871,0.023122,4.084950,0.450140,0.557182,0.193270,0.981858,0.001658,0.007330,0.000860,2.153944,0.179503,1.365095,0.134915,17.405300,3.697777,15.918405,3.567899,0.963044,0.004114,0.002546,0.001213,5.821914,0.726013,0.666829,0.206732,68.837270,18.840298,15.755592,9.293891,0.682163,0.066610,0.154895,0.048179,9.716294,1.651805,6.772706,1.191247,401.011939,186.210922,381.913845,181.161912,0.306186,0.197140,0.188337,0.135128,1.789345,10.745125,0.972451,7.769104,234.924605,0.494174
300,P4J6B05SL7,30,0.1,0.00010,0.00010,2500,5.047583,0.013552,1.424746,0.077926,0.138871,0.023122,4.084950,0.450140,0.557182,0.193270,0.981858,0.001658,0.007330,0.000860,2.153944,0.179503,1.365095,0.134915,17.405300,3.697777,15.918405,3.567899,0.963044,0.004114,0.002546,0.001213,5.821914,0.726013,0.666829,0.206732,68.837270,18.840298,15.755592,9.293891,0.682163,0.066610,0.154895,0.048179,9.716294,1.651805,6.772706,1.191247,401.011939,186.210922,381.913845,181.161912,0.306186,0.197140,0.188337,0.135128,1.789345,10.745125,0.972451,7.769104,234.924605,0.494174
614,W11TLWHTOL,30,0.1,0.00001,0.00001,2500,4.976541,0.024706,1.424746,0.077926,0.138871,0.023122,4.084950,0.450140,0.557182,0.193270,0.981858,0.001658,0.007330,0.000860,2.153944,0.179503,1.365095,0.134915,17.405300,3.697777,15.918405,3.567899,0.963044,0.004114,0.002546,0.001213,5.821914,0.726013,0.666829,0.206732,68.837270,18.840298,15.755592,9.293891,0.682163,0.066610,0.154895,0.048179,9.716294,1.651805,6.772706,1.191247,401.011939,186.210922,381.913845,181.161912,0.306186,0.197140,0.188337,0.135128,1.789345,10.745125,0.972451,7.769104,234.924605,0.494174
439,O7LEIHJTO3,30,0.1,0.00010,0.00001,2500,5.042455,0.008222,1.424746,0.077926,0.138871,0.023122,4.084950,0.450140,0.557182,0.193270,0.981858,0.001658,0.007330,0.000860,2.153944,0.179503,1.365095,0.134915,17.405300,3.697777,15.918405,3.567899,0.963044,0.004114,0.002546,0.001213,5.821914,0.726013,0.666829,0.206732,68.837270,18.840298,15.755592,9.293891,0.682163,0.066610,0.154895,0.048179,9.716294,1.651805,6.772706,1.191247,401.011939,186.210922,381.913845,181.161912,0.306186,0.197140,0.188337,0.135128,1.789345,10.745125,0.972451,7.769104,234.924605,0.494174
354,2VX2B02OCI,30,0.1,0.00010,0.00010,1000,2.010655,0.015507,1.426980,0.078377,0.139530,0.022876,4.097623,0.453896,0.564483,0.194676,0.981795,0.001677,0.007376,0.000870,2.153678,0.179306,1.364051,0.134557,17.389738,3.702126,15.901711,3.575452,0.963046,0.004100,0.002571,0.001321,5.822791,0.725209,0.665043,0.209668,68.833591,18.811968,15.661992,9.217289,0.682217,0.066858,0.154524,0.048418,9.717465,1.644471,6.772904,1.187852,400.435943,184.620857,381.326167,179.574749,0.305868,0.196768,0.188057,0.136142,1.790329,10.743681,0.972421,7.770128,234.634767,0.494042
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
565,VYR2MH2TH8,13,0.5,0.00100,0.00010,700,2.442065,0.027420,3.338153,0.157470,0.330694,0.061109,21.109586,2.022034,2.703306,0.986546,0.906335,0.008495,0.037241,0.004442,4.725910,0.325079,2.824768,0.231226,76.708144,12.238389,68.479392,11.637994,0.814952,0.019673,0.033824,0.004856,6.043590,0.653123,0.680258,0.185602,73.030794,16.392186,14.835586,7.542519,0.662325,0.054769,0.158780,0.048795,9.711567,1.339709,6.652797,0.896073,391.367004,148.793958,371.068083,143.535203,0.287214,0.181390,0.178862,0.147871,4.032031,48.908865,0.860643,7.877579,232.198899,0.474769
140,ZP6ZPPE78V,13,0.5,0.00100,0.00100,700,2.439470,0.026997,3.338153,0.157470,0.330694,0.061109,21.109586,2.022034,2.703306,0.986546,0.906335,0.008495,0.037241,0.004442,4.725910,0.325079,2.824768,0.231226,76.708144,12.238389,68.479392,11.637994,0.814952,0.019673,0.033824,0.004856,6.043590,0.653123,0.680258,0.185602,73.030794,16.392186,14.835586,7.542519,0.662325,0.054769,0.158780,0.048795,9.711567,1.339709,6.652797,0.896073,391.367004,148.793958,371.068083,143.535203,0.287214,0.181390,0.178862,0.147871,4.032031,48.908865,0.860643,7.877579,232.198899,0.474769
5,QOO2KXMF8E,13,0.5,0.00100,0.00010,1000,3.518824,0.021741,3.338052,0.159060,0.330916,0.061670,21.098893,2.030809,2.701613,0.989073,0.906386,0.008507,0.037218,0.004438,4.723971,0.326487,2.823578,0.233312,76.609886,12.198596,68.384657,11.604627,0.815107,0.019522,0.033876,0.004980,6.044723,0.652355,0.684429,0.184069,73.036743,16.349264,14.891498,7.542429,0.662116,0.054872,0.159102,0.048866,9.715839,1.345654,6.657023,0.903888,391.560290,149.503431,371.264317,144.285875,0.287043,0.181611,0.179037,0.148121,4.031011,48.854390,0.860747,7.880281,232.298516,0.474580
313,4WIFI6RUH8,13,0.5,0.00100,0.00001,1000,3.454989,0.017333,3.338052,0.159060,0.330916,0.061670,21.098893,2.030809,2.701613,0.989073,0.906386,0.008507,0.037218,0.004438,4.723971,0.326487,2.823578,0.233312,76.609886,12.198596,68.384657,11.604627,0.815107,0.019522,0.033876,0.004980,6.044723,0.652355,0.684429,0.184069,73.036743,16.349264,14.891498,7.542429,0.662116,0.054872,0.159102,0.048866,9.715839,1.345654,6.657023,0.903888,391.560290,149.503431,371.264317,144.285875,0.287043,0.181611,0.179037,0.148121,4.031011,48.854390,0.860747,7.880281,232.298516,0.474580


In [7]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'max_depth': 30, 'max_features': 0.1, 'min_samples_leaf': 1e-05, 'min_samples_split': 0.0001, 'n_estimators': 2500}


## Best model

In [8]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_X_train_{}.npy'.format(JOINT, DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_X_test_{}.npy'.format(JOINT, DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_Y_train_{}.npy'.format(JOINT, DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_Y_test_{}.npy'.format(JOINT, DATA_ID)))

In [9]:
# Setup the model with the best parameters
model = RandomForestRegressor(**best_params, random_state=0, n_jobs=-1, verbose=1)

model.fit(X_train, Y_train)

# Save the model
dump(model, os.path.join(RESULTS_PATH, DATA_ID, '{}_RF_{}'.format(JOINT, HS_DATE), '{}_RF_best_model_{}_{}.joblib'.format(JOINT, HS_DATE, DATA_ID))) 

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    0.1s
[Parallel(n_jobs=-1)]: Done 184 tasks      | elapsed:    0.4s
[Parallel(n_jobs=-1)]: Done 434 tasks      | elapsed:    0.9s
[Parallel(n_jobs=-1)]: Done 784 tasks      | elapsed:    1.7s
[Parallel(n_jobs=-1)]: Done 1234 tasks      | elapsed:    2.7s
[Parallel(n_jobs=-1)]: Done 1784 tasks      | elapsed:    3.9s
[Parallel(n_jobs=-1)]: Done 2434 tasks      | elapsed:    5.3s
[Parallel(n_jobs=-1)]: Done 2500 out of 2500 | elapsed:    5.5s finished


['../../../../results/0013_09082021/Hip_RF_13082021/Hip_RF_best_model_13082021_0013_09082021.joblib']

In [10]:
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

results = {
    'Train': {
        'MAE': mean_absolute_error(Y_train, train_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_train, train_preds, multioutput='raw_values'),
        'R2': r2_score(Y_train, train_preds, multioutput='raw_values')
    },
    'Test': {
        'MAE': mean_absolute_error(Y_test, test_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_test, test_preds, multioutput='raw_values'),
        'R2': r2_score(Y_test, test_preds, multioutput='raw_values')
    }       
    
}

# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results[subset][loss][i + f] for i in range(0, len(CELLS) * 2, 2)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done 184 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Done 434 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Done 784 tasks      | elapsed:    0.2s
[Parallel(n_jobs=8)]: Done 1234 tasks      | elapsed:    0.3s
[Parallel(n_jobs=8)]: Done 1784 tasks      | elapsed:    0.5s
[Parallel(n_jobs=8)]: Done 2434 tasks      | elapsed:    0.7s
[Parallel(n_jobs=8)]: Done 2500 out of 2500 | elapsed:    0.7s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done 184 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done 434 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Done 784 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Done 1234 tasks      | elapsed:    0.2s


Train Fx MAE: 1.4383 ± 0.1453
Train Fx MSE: 4.1693 ± 0.6342
Train Fx R2: 0.9814 ± 0.0078
Train Fy MAE: 2.2358 ± 1.4441
Train Fy MSE: 19.0880 ± 17.6000
Train Fy R2: 0.9614 ± 0.0008
Test Fx MAE: 5.8537 ± 0.5953
Test Fx MSE: 70.5377 ± 19.4663
Test Fx R2: 0.7938 ± 0.0597
Test Fy MAE: 9.6492 ± 6.9066
Test Fy MSE: 376.1152 ± 359.6698
Test Fy R2: 0.6109 ± 0.0379


[Parallel(n_jobs=8)]: Done 1784 tasks      | elapsed:    0.3s
[Parallel(n_jobs=8)]: Done 2434 tasks      | elapsed:    0.4s
[Parallel(n_jobs=8)]: Done 2500 out of 2500 | elapsed:    0.4s finished


In [11]:
model.feature_importances_

array([0.06663225, 0.05709593, 0.06912625, 0.06582222, 0.05898258,
       0.07084185, 0.06822455, 0.05788318, 0.07201463, 0.06903547,
       0.0604749 , 0.07193388, 0.07474517, 0.06273663, 0.0744505 ])

In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()