In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor

## Configuration

In [2]:
# Force cells in the robotic leg
CELLS = [3, 4, 7, 8]

# Path where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the training and validation data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0008_22072021'
# Hyperparameters search date
HS_DATE = '22072021'
# Number of folds in cross-validation
CV = 6

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0008_22072021


## Hyperparameters seach analysis

In [26]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_{}_*.json'.format(HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 212


In [27]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std
0,0053CDD6RU,50,0.5,10,20,10000,201.890455,1.777771,3.736151,0.112171,0.612247,0.038207,34.358118,2.026816,10.225985,0.764547,0.899108,0.006116,0.021815,0.002756,4.222927,0.032640,1.719379,0.038264,42.771542,0.599828,28.171035,0.701056,0.808222,0.003476,0.082249,0.004090,8.096202,0.727818,1.716026,0.549847,139.993334,26.463015,55.716064,24.496856,0.575438,0.087509,0.113608,0.045170,7.644204,0.415855,3.121978,0.442792,127.938003,10.030636,77.716880,12.841295,0.446647,0.071425,0.196926,0.100499
1,03LDFNOFGH,50,0.7,5,2,1000,30.840562,0.346915,2.650750,0.091164,0.425651,0.027319,17.590059,1.263665,5.101886,0.385094,0.948171,0.003863,0.011456,0.001543,3.047546,0.034458,1.178875,0.024499,22.656202,0.556882,14.389199,0.380046,0.893066,0.002750,0.050485,0.002207,7.969503,0.716748,1.670227,0.513883,136.023114,24.607592,53.232610,21.365071,0.584398,0.093226,0.110936,0.046993,7.439682,0.405848,3.040868,0.414780,122.112699,9.009124,74.259189,12.064453,0.467321,0.071489,0.195004,0.100667
2,061HB0EQT8,13,0.7,2,20,1000,28.579041,0.074339,3.701306,0.124426,0.581178,0.043149,30.530624,1.938824,8.382298,0.777786,0.909328,0.005910,0.021288,0.002607,4.218877,0.045854,1.710495,0.046744,40.590266,0.691653,27.352535,0.897998,0.816577,0.003311,0.080754,0.003808,8.096809,0.723109,1.670112,0.553595,138.758619,24.876122,53.505312,22.587005,0.576646,0.090707,0.111430,0.045420,7.613324,0.404981,3.109089,0.441891,126.621205,9.323795,77.020564,12.880402,0.449031,0.073805,0.196500,0.102723
3,0ECIQH6JIS,20,0.3,20,2,10000,112.413255,0.845373,4.870046,0.121530,0.828010,0.069738,57.782599,3.202461,18.961023,1.868171,0.833252,0.008480,0.031306,0.003976,5.466373,0.048516,2.308814,0.047691,70.862660,0.845228,47.607168,1.189424,0.701035,0.004139,0.116353,0.006726,8.244542,0.735303,1.721217,0.638640,145.031393,31.993636,58.421075,30.664391,0.565418,0.085620,0.110811,0.041208,7.919833,0.421240,3.256536,0.458714,136.494771,10.575385,83.757612,13.770335,0.420626,0.065226,0.192487,0.091239
4,0EXTVX4JY7,20,0.5,5,20,5000,105.714867,0.959715,3.447854,0.105891,0.548016,0.036185,28.880617,1.798413,8.154574,0.615347,0.914546,0.005362,0.019461,0.002397,3.922548,0.032420,1.570103,0.037402,36.733834,0.573335,24.005029,0.668698,0.831963,0.002941,0.075249,0.003455,8.060601,0.731816,1.685514,0.538854,138.252876,25.686860,54.051958,23.079408,0.578789,0.090948,0.112189,0.046226,7.586792,0.405176,3.100309,0.434901,126.218179,9.333729,76.868570,12.289142,0.453664,0.071122,0.193746,0.099320
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
207,BC1T31HIU7,20,0.2,2,20,5000,49.664191,0.247200,3.419503,0.105882,0.552961,0.034914,28.332406,1.838510,8.076045,0.625299,0.916231,0.005293,0.019103,0.002267,3.949402,0.033384,1.620042,0.040474,37.118754,0.475657,24.912948,0.650818,0.833285,0.002290,0.074292,0.002902,7.926188,0.745080,1.647215,0.551303,133.066775,29.108476,52.537593,25.572794,0.596162,0.092196,0.105124,0.047081,7.513793,0.389564,3.097328,0.405361,123.496239,8.233227,76.133571,10.772004,0.475607,0.061686,0.177181,0.081949
208,BHE44L93UJ,30,0.5,20,2,700,18.403573,1.291808,4.845809,0.125590,0.822179,0.064649,57.372918,3.139442,18.684843,1.804791,0.834296,0.008445,0.031244,0.003753,5.381271,0.045506,2.253431,0.044797,68.579233,0.814144,45.751485,1.087175,0.705521,0.005294,0.116809,0.007390,8.327788,0.717363,1.733939,0.629089,148.004097,29.528582,58.889610,29.212571,0.554654,0.083351,0.115448,0.041840,7.932509,0.425297,3.245599,0.477763,137.414414,11.237110,83.667982,14.797595,0.410027,0.072875,0.202736,0.102238
209,BRTLGJPKYW,13,0.1,2,20,700,3.430241,0.019983,4.080705,0.109129,0.673367,0.060818,37.681116,2.267017,11.186947,1.193662,0.889249,0.006340,0.024365,0.002744,4.690411,0.027534,1.985428,0.056530,50.705503,0.315484,34.824675,1.144651,0.785250,0.003148,0.087219,0.004778,7.867259,0.738031,1.614450,0.617538,130.491420,31.964963,52.412331,29.432938,0.608136,0.087927,0.097288,0.043724,7.587417,0.371288,3.162164,0.366546,125.657138,6.097012,79.004419,8.872032,0.484099,0.042553,0.155712,0.058834
210,BWKXE5384C,13,0.1,20,5,2500,10.739815,0.058693,5.306756,0.122194,0.938883,0.090410,67.096823,3.834928,23.270312,2.767860,0.807929,0.009196,0.034156,0.005081,5.964228,0.048625,2.535105,0.054577,84.272934,0.691187,56.131484,1.522760,0.662715,0.002888,0.116457,0.007373,8.130360,0.754062,1.687388,0.681891,140.239078,35.633027,57.937362,33.278516,0.584598,0.085457,0.098158,0.039037,8.014023,0.408730,3.359197,0.357674,140.656281,6.848292,89.073467,9.167788,0.431026,0.039569,0.159730,0.059318


In [28]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Valid']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy']]].mean(axis=1)

In [29]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Valid_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
74,V24FH0UC1W,15,0.1,2,2,10000,55.433815,0.365775,2.296871,0.079446,0.396317,0.037432,11.732212,0.920337,3.385014,0.403451,0.965310,0.002982,0.008106,0.001118,2.815122,0.035594,1.208750,0.041252,18.799030,0.436795,13.565041,0.554038,0.920860,0.002329,0.034056,0.002084,7.651563,0.755338,1.584823,0.523207,124.222539,30.033897,49.943362,25.602384,0.622925,0.092683,0.096605,0.050295,7.252702,0.365828,3.034126,0.352281,116.585877,6.276876,73.496477,8.092109,0.519276,0.045046,0.153539,0.061081,2.555997,15.265621,0.943085,7.452133,120.404208,0.571101
98,7R02CFX48U,50,0.1,2,2,10000,59.208885,0.235305,1.850461,0.065129,0.328078,0.022181,8.687453,0.688850,2.672017,0.246840,0.974567,0.002000,0.005556,0.000717,2.228244,0.024905,0.926811,0.021581,12.490288,0.224083,8.491647,0.185703,0.945126,0.001149,0.024331,0.000832,7.650978,0.762719,1.585448,0.509426,124.582629,30.095099,50.004981,25.199315,0.621185,0.094385,0.097542,0.051336,7.228075,0.368219,3.020626,0.346609,116.276767,6.349649,73.131579,8.037413,0.519680,0.046109,0.153839,0.062024,2.039352,10.588870,0.959846,7.439526,120.429698,0.570432
191,Q23U62EVVF,50,0.1,2,2,700,4.252936,0.050852,1.856301,0.065399,0.330026,0.022740,8.780408,0.716602,2.722733,0.269942,0.974334,0.002039,0.005541,0.000725,2.233361,0.025622,0.928603,0.021984,12.584609,0.226383,8.532583,0.197831,0.944754,0.001143,0.024432,0.000867,7.658540,0.756697,1.581664,0.507181,124.765552,30.128455,49.912121,25.209502,0.620543,0.093914,0.097413,0.051405,7.229480,0.362742,3.022107,0.345232,116.251120,6.028564,73.064783,8.065914,0.519754,0.045449,0.154187,0.061513,2.044831,10.682509,0.959544,7.444010,120.508336,0.570149
64,TSHBOIBUXK,15,0.1,2,5,10000,56.254231,0.470965,2.404137,0.081543,0.411103,0.038023,12.927218,0.982097,3.705170,0.418853,0.961756,0.003161,0.008917,0.001218,2.930939,0.035743,1.246868,0.041493,20.315568,0.443831,14.475270,0.560693,0.913461,0.002337,0.037219,0.002066,7.662717,0.755368,1.586277,0.524433,124.501903,30.077141,50.081352,25.677183,0.622141,0.092595,0.096918,0.050186,7.268472,0.365178,3.039564,0.353230,117.035985,6.208482,73.789147,8.043481,0.517426,0.045094,0.153644,0.060964,2.667538,16.621393,0.937609,7.465595,120.768944,0.569784
176,6BZMRVEYPP,13,0.1,2,2,5000,26.528616,0.194066,2.897698,0.089359,0.483617,0.053530,17.259588,1.185368,4.746406,0.658424,0.948507,0.004127,0.012973,0.001703,3.508937,0.038752,1.529810,0.056306,28.292472,0.527744,21.109705,1.003738,0.883873,0.003317,0.051255,0.003657,7.674408,0.747361,1.579029,0.548122,124.114587,29.909213,49.686382,26.139441,0.623753,0.090377,0.095367,0.048659,7.317514,0.365817,3.063096,0.361134,117.733043,6.042328,74.332495,8.317944,0.515317,0.043497,0.153876,0.059840,3.203317,22.776030,0.916190,7.495961,120.923815,0.569535
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
152,XDNUMMILO7,13,0.7,20,20,2500,62.918236,0.525364,4.905542,0.129178,0.833588,0.064485,58.080089,3.181356,19.001450,1.868398,0.832433,0.008522,0.031305,0.003350,5.407495,0.050577,2.251718,0.045039,68.628086,0.963211,45.666255,1.085827,0.702972,0.006295,0.118944,0.007913,8.355847,0.725355,1.724815,0.625518,148.809856,29.047404,58.806723,28.513347,0.551564,0.083291,0.115645,0.041626,7.928605,0.418173,3.235581,0.482264,137.269209,11.273713,83.374974,15.296492,0.406856,0.075888,0.206667,0.105357,5.156518,63.354088,0.767702,8.142226,143.039533,0.479210
48,RCRGRKFCW9,30,0.7,20,20,10000,254.881377,2.790281,4.833324,0.128757,0.822476,0.061009,57.184032,3.178526,18.766607,1.786120,0.835109,0.008468,0.030631,0.003328,5.331871,0.043171,2.219512,0.043786,67.279811,0.810996,44.704871,1.044709,0.708011,0.005822,0.117517,0.007732,8.355820,0.723438,1.721319,0.624316,148.980683,29.060457,58.787117,28.407769,0.550807,0.083878,0.115861,0.041477,7.920039,0.418649,3.231401,0.479264,137.106428,11.215611,83.218288,15.232838,0.407201,0.076154,0.206561,0.105388,5.082597,62.231921,0.771560,8.137930,143.043556,0.479004
59,TJNK0YLS4U,13,0.7,20,2,5000,125.448766,0.872667,4.905307,0.129011,0.833790,0.064450,58.078244,3.182132,19.006191,1.866506,0.832446,0.008486,0.031293,0.003309,5.406831,0.050340,2.251158,0.045044,68.606687,0.957676,45.644422,1.089533,0.703006,0.006348,0.118951,0.007886,8.358041,0.725403,1.724809,0.624590,148.863992,28.967740,58.790478,28.439971,0.551318,0.083573,0.115639,0.041557,7.929272,0.415525,3.235342,0.482308,137.264910,11.184588,83.350102,15.332423,0.406670,0.076087,0.206744,0.105547,5.156069,63.342466,0.767726,8.143657,143.064451,0.478994
211,BY4P0SL7CY,20,0.7,20,20,5000,127.390901,1.290989,4.833548,0.128578,0.822289,0.061105,57.199533,3.181340,18.762961,1.787543,0.835056,0.008450,0.030659,0.003311,5.332154,0.042680,2.219300,0.043973,67.297564,0.806009,44.701793,1.052942,0.707969,0.005840,0.117458,0.007748,8.355812,0.722714,1.720947,0.623470,148.993903,29.018487,58.770261,28.344563,0.550725,0.083839,0.115866,0.041384,7.921134,0.419125,3.232245,0.478320,137.137621,11.197130,83.239444,15.206857,0.407156,0.076079,0.206599,0.105365,5.082851,62.248549,0.771513,8.138473,143.065762,0.478940


In [30]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'max_depth': 15, 'max_features': 0.1, 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 10000}


## Best model

In [17]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_train_{}.npy'.format(DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_test_{}.npy'.format(DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_train_{}.npy'.format(DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_test_{}.npy'.format(DATA_ID)))

In [18]:
# Setup the model with the best parameters
model = RandomForestRegressor(**best_params, random_state=0, n_jobs=-1, verbose=1)

model.fit(X_train, Y_train)

# Save the model
dump(model, os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_best_model_{}_{}.joblib'.format(HS_DATE, DATA_ID))) 

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    0.6s
[Parallel(n_jobs=-1)]: Done 184 tasks      | elapsed:    2.2s
[Parallel(n_jobs=-1)]: Done 434 tasks      | elapsed:    5.2s
[Parallel(n_jobs=-1)]: Done 784 tasks      | elapsed:    9.3s
[Parallel(n_jobs=-1)]: Done 1234 tasks      | elapsed:   14.3s
[Parallel(n_jobs=-1)]: Done 1784 tasks      | elapsed:   20.5s
[Parallel(n_jobs=-1)]: Done 2434 tasks      | elapsed:   28.2s
[Parallel(n_jobs=-1)]: Done 3184 tasks      | elapsed:   37.2s
[Parallel(n_jobs=-1)]: Done 4034 tasks      | elapsed:   47.5s
[Parallel(n_jobs=-1)]: Done 4984 tasks      | elapsed:   59.5s
[Parallel(n_jobs=-1)]: Done 6034 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-1)]: Done 7184 tasks      | elapsed:  1.5min
[Parallel(n_jobs=-1)]: Done 8434 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 9784 tasks      | elapsed:  2.0min
[Parallel(n_jobs=-1)]: Done 10000 out of 

['../../../../results/0007_19072021/RF_21072021/RF_best_model_21072021_0007_19072021.joblib']

In [20]:
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

results = {
    'Train': {
        'MAE': mean_absolute_error(Y_train, train_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_train, train_preds, multioutput='raw_values'),
        'R2': r2_score(Y_train, train_preds, multioutput='raw_values')
    },
    'Test': {
        'MAE': mean_absolute_error(Y_test, test_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_test, test_preds, multioutput='raw_values'),
        'R2': r2_score(Y_test, test_preds, multioutput='raw_values')
    }       
    
}

# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results[subset][loss][i + f] for i in range(0, len(CELLS) * 2, 2)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done 184 tasks      | elapsed:    0.2s
[Parallel(n_jobs=8)]: Done 434 tasks      | elapsed:    0.4s
[Parallel(n_jobs=8)]: Done 784 tasks      | elapsed:    0.6s
[Parallel(n_jobs=8)]: Done 1234 tasks      | elapsed:    0.9s
[Parallel(n_jobs=8)]: Done 1784 tasks      | elapsed:    1.4s
[Parallel(n_jobs=8)]: Done 2434 tasks      | elapsed:    1.9s
[Parallel(n_jobs=8)]: Done 3184 tasks      | elapsed:    2.5s
[Parallel(n_jobs=8)]: Done 4034 tasks      | elapsed:    3.1s
[Parallel(n_jobs=8)]: Done 4984 tasks      | elapsed:    3.8s
[Parallel(n_jobs=8)]: Done 6034 tasks      | elapsed:    4.6s
[Parallel(n_jobs=8)]: Done 7184 tasks      | elapsed:    5.4s
[Parallel(n_jobs=8)]: Done 8434 tasks      | elapsed:    6.3s
[Parallel(n_jobs=8)]: Done 9784 tasks      | elapsed:    7.2s
[Parallel(n_jobs=8)]: Done 10000 out of 10000 | elapsed:

Train Fx MAE: 2.2904 ± 0.4356
Train Fx MSE: 13.1327 ± 3.9882
Train Fx R2: 0.9611 ± 0.0104
Train Fy MAE: 2.7660 ± 1.1741
Train Fy MSE: 19.6227 ± 14.1450
Train Fy R2: 0.9152 ± 0.0379
Test Fx MAE: 7.9542 ± 2.5406
Test Fx MSE: 164.2414 ± 107.3285
Test Fx R2: 0.6927 ± 0.0674
Test Fy MAE: 7.0967 ± 2.7796
Test Fy MSE: 117.8615 ± 70.1520
Test Fy R2: 0.5324 ± 0.1616


[Parallel(n_jobs=8)]: Done 9784 tasks      | elapsed:    2.2s
[Parallel(n_jobs=8)]: Done 10000 out of 10000 | elapsed:    2.3s finished


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()