In [10]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor

## Configuration

In [11]:
# Force cells in the robotic leg
CELLS = [3, 4, 7, 8]

# Path where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the training and validation data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0007_19072021'
# Hyperparameters search date
HS_DATE = '21072021'
# Number of folds in cross-validation
CV = 6

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0007_19072021


## Hyperparameters seach analysis

In [12]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_{}_*.json'.format(HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 145


In [13]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std
0,00B7JPV9HJ,30,5,20,20,10000,200.494256,2.268230,4.333499,0.125969,0.802363,0.039455,50.590779,3.150419,16.457793,1.468712,0.852207,0.009818,0.034202,0.004188,4.790325,0.052797,1.910805,0.026349,57.650964,0.983040,36.450471,0.854469,0.738047,0.002263,0.111072,0.005059,7.734570,0.622075,1.805535,0.469639,136.936914,30.809410,61.313807,31.651977,0.597466,0.089385,0.094848,0.045376,7.498600,0.507893,3.014268,0.420703,126.504364,16.510835,77.160412,16.283083,0.436291,0.043801,0.173882,0.065707
1,05KSHINUOE,13,4,5,20,2500,43.129761,0.314429,3.518324,0.116267,0.634585,0.029905,30.846941,2.308118,9.359289,0.963340,0.908443,0.007361,0.023982,0.002783,4.072298,0.046581,1.639519,0.023473,40.431326,0.679772,26.575191,0.551940,0.816070,0.003161,0.080179,0.003453,7.499704,0.630351,1.730194,0.442848,128.321588,29.775285,57.108110,29.862738,0.620118,0.088846,0.098063,0.043783,7.221304,0.507090,2.913568,0.423569,116.780222,15.340174,71.593289,15.307783,0.478640,0.032903,0.164655,0.059265
2,0RFPOMBD0R,50,3,2,20,5000,75.903087,1.151411,2.924581,0.101239,0.521352,0.026375,23.570880,1.879249,6.884185,0.808395,0.929561,0.006215,0.019056,0.002547,3.425352,0.047184,1.344653,0.017195,30.292951,0.738079,19.111265,0.280555,0.857174,0.003163,0.067534,0.002187,7.422904,0.599233,1.709292,0.414144,125.839425,28.947955,55.748032,28.900149,0.626241,0.090638,0.097754,0.045018,7.122659,0.479535,2.880868,0.395743,113.443329,13.656001,69.611529,14.425677,0.492926,0.029002,0.162222,0.056609
3,196JFNVFEN,13,5,5,5,700,15.204689,0.111135,3.218389,0.107296,0.592915,0.019716,25.855448,1.706239,8.198374,0.709650,0.923608,0.005444,0.019872,0.002139,3.764299,0.065539,1.536703,0.022972,35.399537,0.835061,24.076650,0.707125,0.842489,0.003075,0.067421,0.003306,7.586872,0.668327,1.741243,0.491116,133.057623,31.295869,58.687449,31.304789,0.605734,0.092142,0.100530,0.043015,7.285800,0.514324,2.942949,0.413900,120.498676,15.090717,73.876797,15.008477,0.461684,0.029012,0.170054,0.067095
4,1PUVDLU8SP,15,2,10,2,1000,10.247102,0.087978,3.658773,0.112425,0.687570,0.042160,36.121668,2.685466,11.551980,1.346179,0.893978,0.008586,0.025737,0.003223,4.182826,0.036727,1.705752,0.029237,43.937016,0.648026,28.490381,0.610278,0.801381,0.003457,0.085983,0.003149,7.475048,0.566455,1.685056,0.454992,126.076687,29.742728,55.596906,30.484349,0.627089,0.085714,0.095555,0.043037,7.223407,0.437482,2.938706,0.383087,115.603484,12.349489,71.424036,13.415759,0.488296,0.032625,0.159484,0.054686
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
140,Z5IB15JLSB,13,4,10,2,5000,84.930470,0.668064,3.702540,0.118785,0.677537,0.033467,35.368980,2.621823,11.087026,1.161135,0.895763,0.008271,0.025920,0.003009,4.231224,0.045664,1.707233,0.025110,44.103669,0.677184,28.701668,0.572200,0.799854,0.003291,0.086314,0.003654,7.511294,0.624662,1.738124,0.439297,128.662883,29.794593,57.562685,29.971368,0.620095,0.088533,0.096479,0.044007,7.253019,0.510686,2.930720,0.420350,117.832672,15.468870,72.260275,15.261113,0.475033,0.034402,0.165025,0.058707
141,Z5R857D1VG,50,3,20,5,1000,13.029638,0.211026,4.343214,0.122266,0.815542,0.050510,50.077704,3.469772,16.569937,1.865317,0.853999,0.010794,0.033916,0.004086,4.856243,0.045062,1.964703,0.029964,58.061777,0.819316,37.191321,0.745702,0.738058,0.003116,0.110721,0.004356,7.584224,0.588660,1.726397,0.433910,129.300345,29.965216,57.204938,30.371066,0.618503,0.086207,0.094544,0.044622,7.389166,0.468842,2.995509,0.409943,121.244780,14.179930,74.893857,14.860636,0.461942,0.037524,0.167866,0.057920
142,ZFTRVYIDHE,20,4,10,2,700,12.640622,0.444729,3.311889,0.111723,0.600591,0.030129,31.541020,2.554824,9.678015,1.065828,0.906850,0.008207,0.023146,0.003066,3.793719,0.046704,1.499881,0.021092,37.762706,0.781744,23.737066,0.409808,0.824325,0.003773,0.079442,0.002727,7.494584,0.626687,1.736546,0.440749,128.882911,29.779534,57.715397,29.814324,0.619237,0.089989,0.096770,0.044586,7.219600,0.524853,2.921813,0.412584,117.797417,15.824810,72.388790,15.360704,0.475482,0.036117,0.166247,0.057257
143,ZPJ716JWZZ,20,2,5,5,5000,56.334870,0.437882,2.649625,0.092427,0.497022,0.028493,20.763799,1.685279,6.290347,0.735378,0.938521,0.005663,0.015543,0.002275,3.104097,0.039454,1.247870,0.018542,26.034863,0.608813,16.643428,0.282538,0.879566,0.003059,0.055038,0.002114,7.385810,0.576536,1.670656,0.443405,124.192344,28.865290,54.514746,29.343981,0.630659,0.088711,0.097186,0.043256,7.071493,0.446851,2.869144,0.381099,111.674981,12.128530,68.805372,13.482229,0.503544,0.030213,0.157110,0.055130


In [14]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Valid']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy']]].mean(axis=1)

In [15]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Valid_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
77,UKGA59J3QK,15,2,2,2,10000,118.469565,1.371736,2.157603,0.071611,0.418886,0.023265,11.872508,0.807291,3.699457,0.329671,0.964750,0.003004,0.009559,0.001385,2.652895,0.030995,1.133281,0.030206,18.315311,0.353627,13.337732,0.548958,0.920696,0.002457,0.036196,0.002478,7.348685,0.591469,1.669167,0.427275,123.110924,28.046650,53.908191,28.308120,0.632311,0.090179,0.098601,0.044533,6.988116,0.448829,2.831726,0.384423,109.061491,11.579567,67.115304,12.992329,0.513900,0.023895,0.155290,0.058050,2.405249,15.093909,0.942723,7.168400,116.086208,0.573106
21,5S7DMPGMKJ,30,1,2,2,5000,43.484329,0.630491,2.005917,0.064912,0.388423,0.024762,11.815050,0.884715,3.634091,0.448954,0.965032,0.003020,0.008989,0.001284,2.349866,0.030474,0.964891,0.014975,15.092283,0.361832,9.817125,0.185838,0.931803,0.001800,0.029940,0.001243,7.408960,0.530388,1.625273,0.472742,123.563797,28.004225,52.963667,29.727039,0.631342,0.083620,0.097075,0.042867,7.092686,0.386053,2.879806,0.347562,111.083546,9.049111,68.741997,11.291181,0.510709,0.028222,0.150654,0.054805,2.177892,13.453666,0.948418,7.250823,117.323672,0.571026
23,KWS6U9D9QZ,20,2,2,10,1000,11.973391,0.105834,2.278732,0.078942,0.421949,0.022559,15.062496,1.207470,4.400732,0.516264,0.955042,0.004160,0.012011,0.001815,2.720397,0.036598,1.080388,0.015809,19.959247,0.497962,12.765582,0.204601,0.906479,0.002407,0.044366,0.001633,7.361502,0.584816,1.671856,0.437386,123.581301,28.434253,54.209595,28.804044,0.631666,0.090402,0.097953,0.043798,7.017960,0.444607,2.835790,0.373992,109.870155,11.604260,67.476143,12.988897,0.509558,0.026681,0.156528,0.056786,2.499564,17.510871,0.930760,7.189731,116.725728,0.570612
123,JCGTSZ52CB,15,2,2,10,10000,117.898769,2.038102,2.619450,0.086253,0.488802,0.026498,17.744725,1.284154,5.295334,0.528232,0.947134,0.004529,0.014199,0.001964,3.149437,0.032181,1.296917,0.028178,25.167845,0.446880,17.130795,0.507940,0.886378,0.002828,0.051439,0.002388,7.370500,0.583502,1.671222,0.428319,123.554693,28.489770,54.210713,28.814079,0.631732,0.089528,0.098080,0.044034,7.039420,0.437749,2.851677,0.380355,110.184095,11.487301,67.779473,13.055270,0.509263,0.025829,0.155854,0.056961,2.884444,21.456285,0.916756,7.204960,116.869394,0.570498
115,HUSKQQQDWE,50,2,5,10,5000,56.506248,1.135015,2.644778,0.092697,0.496161,0.028587,20.730812,1.682675,6.272921,0.732382,0.938609,0.005663,0.015530,0.002276,3.097131,0.040136,1.244167,0.018341,25.954013,0.620934,16.567976,0.285809,0.879859,0.003110,0.054956,0.002111,7.385229,0.573802,1.674644,0.440958,124.173797,28.763832,54.588541,29.238884,0.630760,0.088751,0.097238,0.043442,7.070955,0.447902,2.869070,0.379944,111.657388,12.075436,68.798464,13.393204,0.503818,0.029729,0.156999,0.055080,2.870955,23.342413,0.909234,7.228092,117.915593,0.567289
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46,P3AF9464CN,30,5,20,5,10000,198.531392,2.276956,4.333499,0.125969,0.802363,0.039455,50.590779,3.150419,16.457793,1.468712,0.852207,0.009818,0.034202,0.004188,4.790325,0.052797,1.910805,0.026349,57.650964,0.983040,36.450471,0.854469,0.738047,0.002263,0.111072,0.005059,7.734570,0.622075,1.805535,0.469639,136.936914,30.809410,61.313807,31.651977,0.597466,0.089385,0.094848,0.045376,7.498600,0.507893,3.014268,0.420703,126.504364,16.510835,77.160412,16.283083,0.436291,0.043801,0.173882,0.065707,4.561912,54.120872,0.795127,7.616585,131.720639,0.516879
0,00B7JPV9HJ,30,5,20,20,10000,200.494256,2.268230,4.333499,0.125969,0.802363,0.039455,50.590779,3.150419,16.457793,1.468712,0.852207,0.009818,0.034202,0.004188,4.790325,0.052797,1.910805,0.026349,57.650964,0.983040,36.450471,0.854469,0.738047,0.002263,0.111072,0.005059,7.734570,0.622075,1.805535,0.469639,136.936914,30.809410,61.313807,31.651977,0.597466,0.089385,0.094848,0.045376,7.498600,0.507893,3.014268,0.420703,126.504364,16.510835,77.160412,16.283083,0.436291,0.043801,0.173882,0.065707,4.561912,54.120872,0.795127,7.616585,131.720639,0.516879
22,5VBWHOO8TM,15,5,20,5,1000,20.441718,0.309830,4.362396,0.126133,0.807544,0.038292,50.894939,3.152749,16.584643,1.456533,0.851332,0.009744,0.034444,0.004098,4.822388,0.050321,1.927092,0.031391,58.169050,0.998682,36.885601,0.974538,0.736124,0.002586,0.111683,0.005196,7.741880,0.626373,1.807981,0.472551,136.980711,30.852891,61.371546,31.677136,0.597349,0.089474,0.094993,0.045420,7.499890,0.503741,3.016281,0.419927,126.543848,16.342879,77.229363,16.135218,0.436335,0.043654,0.173746,0.065958,4.592392,54.531994,0.793728,7.620885,131.762280,0.516842
116,HXHRK47HLB,30,5,20,2,700,16.831742,0.447243,4.335075,0.125301,0.801495,0.039831,50.615379,3.162326,16.446888,1.480137,0.852102,0.009799,0.034284,0.004176,4.791343,0.052708,1.910875,0.027461,57.693040,0.996307,36.483837,0.878078,0.737901,0.002266,0.111121,0.004958,7.740552,0.626494,1.805166,0.472099,136.982220,30.845310,61.243515,31.646843,0.597102,0.089434,0.095208,0.045523,7.497869,0.509384,3.016906,0.420792,126.605169,16.500763,77.319743,16.239146,0.436344,0.043925,0.173959,0.065431,4.563209,54.154210,0.795001,7.619211,131.793695,0.516723


In [16]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'max_depth': 15, 'max_features': 2, 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 10000}


## Best model

In [17]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_train_{}.npy'.format(DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_test_{}.npy'.format(DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_train_{}.npy'.format(DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_test_{}.npy'.format(DATA_ID)))

In [18]:
# Setup the model with the best parameters
model = RandomForestRegressor(**best_params, random_state=0, n_jobs=-1, verbose=1)

model.fit(X_train, Y_train)

# Save the model
dump(model, os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_best_model_{}_{}.joblib'.format(HS_DATE, DATA_ID))) 

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    0.6s
[Parallel(n_jobs=-1)]: Done 184 tasks      | elapsed:    2.2s
[Parallel(n_jobs=-1)]: Done 434 tasks      | elapsed:    5.2s
[Parallel(n_jobs=-1)]: Done 784 tasks      | elapsed:    9.3s
[Parallel(n_jobs=-1)]: Done 1234 tasks      | elapsed:   14.3s
[Parallel(n_jobs=-1)]: Done 1784 tasks      | elapsed:   20.5s
[Parallel(n_jobs=-1)]: Done 2434 tasks      | elapsed:   28.2s
[Parallel(n_jobs=-1)]: Done 3184 tasks      | elapsed:   37.2s
[Parallel(n_jobs=-1)]: Done 4034 tasks      | elapsed:   47.5s
[Parallel(n_jobs=-1)]: Done 4984 tasks      | elapsed:   59.5s
[Parallel(n_jobs=-1)]: Done 6034 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-1)]: Done 7184 tasks      | elapsed:  1.5min
[Parallel(n_jobs=-1)]: Done 8434 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 9784 tasks      | elapsed:  2.0min
[Parallel(n_jobs=-1)]: Done 10000 out of 

['../../../../results/0007_19072021/RF_21072021/RF_best_model_21072021_0007_19072021.joblib']

In [20]:
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

results = {
    'Train': {
        'MAE': mean_absolute_error(Y_train, train_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_train, train_preds, multioutput='raw_values'),
        'R2': r2_score(Y_train, train_preds, multioutput='raw_values')
    },
    'Test': {
        'MAE': mean_absolute_error(Y_test, test_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_test, test_preds, multioutput='raw_values'),
        'R2': r2_score(Y_test, test_preds, multioutput='raw_values')
    }       
    
}

# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results[subset][loss][i + f] for i in range(0, len(CELLS) * 2, 2)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done 184 tasks      | elapsed:    0.2s
[Parallel(n_jobs=8)]: Done 434 tasks      | elapsed:    0.4s
[Parallel(n_jobs=8)]: Done 784 tasks      | elapsed:    0.6s
[Parallel(n_jobs=8)]: Done 1234 tasks      | elapsed:    0.9s
[Parallel(n_jobs=8)]: Done 1784 tasks      | elapsed:    1.4s
[Parallel(n_jobs=8)]: Done 2434 tasks      | elapsed:    1.9s
[Parallel(n_jobs=8)]: Done 3184 tasks      | elapsed:    2.5s
[Parallel(n_jobs=8)]: Done 4034 tasks      | elapsed:    3.1s
[Parallel(n_jobs=8)]: Done 4984 tasks      | elapsed:    3.8s
[Parallel(n_jobs=8)]: Done 6034 tasks      | elapsed:    4.6s
[Parallel(n_jobs=8)]: Done 7184 tasks      | elapsed:    5.4s
[Parallel(n_jobs=8)]: Done 8434 tasks      | elapsed:    6.3s
[Parallel(n_jobs=8)]: Done 9784 tasks      | elapsed:    7.2s
[Parallel(n_jobs=8)]: Done 10000 out of 10000 | elapsed:

Train Fx MAE: 2.2904 ± 0.4356
Train Fx MSE: 13.1327 ± 3.9882
Train Fx R2: 0.9611 ± 0.0104
Train Fy MAE: 2.7660 ± 1.1741
Train Fy MSE: 19.6227 ± 14.1450
Train Fy R2: 0.9152 ± 0.0379
Test Fx MAE: 7.9542 ± 2.5406
Test Fx MSE: 164.2414 ± 107.3285
Test Fx R2: 0.6927 ± 0.0674
Test Fy MAE: 7.0967 ± 2.7796
Test Fy MSE: 117.8615 ± 70.1520
Test Fy R2: 0.5324 ± 0.1616


[Parallel(n_jobs=8)]: Done 9784 tasks      | elapsed:    2.2s
[Parallel(n_jobs=8)]: Done 10000 out of 10000 | elapsed:    2.3s finished


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()