In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor

## Configuration

In [2]:
# Force cells in the robotic leg
CELLS = [3, 4, 7, 8]

# Path where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the training and validation data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0009_02082021'
# Hyperparameters search date
HS_DATE = '02082021'
# Number of folds in cross-validation
CV = 4

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0009_02082021


## Hyperparameters seach analysis

In [27]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_{}_*.json'.format(HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 71


In [28]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std
0,EFKWJJFNHT,30,0.7,0.00010,0.00001,10000,87.826264,1.978021,1.172446,0.018250,0.261492,0.014202,4.003900,0.150542,1.312194,0.078055,0.987035,0.000787,0.004713,0.000338,1.747271,0.065975,0.820949,0.028413,9.315967,0.662566,6.989312,0.381400,0.964088,0.002194,0.018391,0.001256,6.002620,0.849992,1.438615,0.195884,76.749590,16.840850,30.502377,8.892280,0.756182,0.019806,0.079778,0.025491,7.068334,1.300428,3.053066,0.530674,119.033924,36.511848,77.304245,22.500995,0.567718,0.090065,0.172376,0.037157
1,QW9QDGP6CW,15,0.3,0.00001,0.00100,700,2.825932,0.042438,2.426092,0.031735,0.545605,0.024702,13.427787,0.379877,4.753839,0.289008,0.957132,0.002514,0.014843,0.001151,3.538743,0.125681,1.737762,0.065770,30.686371,2.102745,23.890805,1.374248,0.886799,0.005604,0.056488,0.003606,6.161318,0.773641,1.548179,0.274627,78.221127,16.002408,32.891383,9.517012,0.755290,0.021504,0.069770,0.025025,7.353733,1.241015,3.273408,0.599013,122.354070,35.914177,81.422696,25.217758,0.567144,0.080620,0.168003,0.037230
2,0R8N9OMWEJ,13,0.3,0.00010,0.00100,1000,3.944760,0.015633,2.805114,0.039808,0.646109,0.032483,17.131198,0.492451,6.454831,0.445584,0.945668,0.003014,0.019214,0.001401,4.059644,0.141899,2.015141,0.080789,39.241174,2.686931,31.005679,1.893623,0.858871,0.006757,0.070187,0.004418,6.222016,0.744613,1.566285,0.266501,79.215215,15.899472,33.496822,9.559490,0.752142,0.021091,0.070294,0.024929,7.432588,1.222203,3.322450,0.598846,123.817327,35.563607,82.631259,25.254215,0.562487,0.078505,0.170266,0.036506
3,0WTQ7D73R8,50,0.1,0.00010,0.00010,1000,4.651081,0.093619,1.553767,0.027680,0.359439,0.014613,6.441418,0.254822,2.204873,0.131634,0.979342,0.001313,0.007187,0.000583,2.320804,0.092413,1.132334,0.040079,15.023870,1.098833,11.382840,0.653827,0.944108,0.003228,0.027650,0.001730,6.073829,0.807617,1.535684,0.254084,77.011942,16.085246,32.409192,9.228612,0.759157,0.022451,0.070560,0.025691,7.230331,1.274001,3.199982,0.601554,120.103978,36.090644,79.583537,24.830238,0.574120,0.082466,0.165818,0.037423
4,24YI1CO78X,15,0.5,0.00010,0.00010,10000,61.750179,0.785847,1.562541,0.014337,0.364602,0.027298,5.952858,0.120451,2.238018,0.243599,0.980942,0.000877,0.007324,0.000461,2.387174,0.077797,1.218519,0.046170,15.424672,1.076527,13.082637,0.811912,0.944373,0.003361,0.030558,0.002714,5.956317,0.797756,1.450640,0.207906,74.633694,15.830466,30.007257,8.686291,0.763201,0.019958,0.076120,0.024958,7.067321,1.275406,3.106906,0.554485,116.630600,35.608499,76.923187,23.050867,0.579219,0.086340,0.168853,0.037389
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
66,Y6AFCEKT6Y,30,0.5,0.00001,0.00001,700,5.458307,0.121644,0.714037,0.008806,0.171160,0.008871,1.563920,0.048664,0.520453,0.034455,0.994918,0.000281,0.001930,0.000159,1.146223,0.042741,0.603427,0.019461,4.288170,0.296937,3.541894,0.183880,0.984736,0.000969,0.008308,0.000635,5.919122,0.834883,1.428775,0.189305,74.256157,16.064973,29.343132,8.048708,0.763855,0.019865,0.077041,0.024735,7.005714,1.278590,3.049970,0.547455,115.972600,35.468850,75.866861,22.568123,0.579876,0.085684,0.167697,0.036880
67,YWMSK9ER2J,50,0.2,0.00010,0.00100,10000,42.309271,0.761295,2.286415,0.049656,0.507291,0.017843,12.921872,0.557993,4.368824,0.218166,0.958624,0.002648,0.013993,0.001158,3.323963,0.143389,1.585221,0.059932,28.556495,2.217521,21.129408,1.350318,0.892472,0.006031,0.052155,0.003059,6.148572,0.788574,1.541901,0.263821,78.268046,16.049297,32.916748,9.378624,0.755227,0.022557,0.070199,0.025562,7.352808,1.261500,3.267340,0.612684,123.167886,36.644633,81.856046,25.947388,0.564869,0.081547,0.168082,0.036815
68,YXPTD58P28,50,0.1,0.00100,0.00001,700,2.636777,0.038429,3.462491,0.088765,0.802154,0.023530,28.254431,1.184276,10.621671,0.627734,0.911148,0.004939,0.028733,0.002370,4.941256,0.225408,2.374734,0.105301,60.506527,4.696738,43.964498,3.251492,0.780439,0.010429,0.099286,0.005788,6.420915,0.791822,1.649123,0.302834,84.331579,17.273302,37.130720,10.837968,0.739138,0.022632,0.069668,0.023398,7.718021,1.239563,3.469262,0.640747,134.648831,38.302209,90.186472,28.868913,0.532886,0.078709,0.173945,0.035661
69,ZDV06TT075,13,0.3,0.00010,0.00001,5000,21.131081,0.387559,2.417408,0.022345,0.584632,0.035773,12.485418,0.269796,5.148277,0.479500,0.960675,0.002168,0.014834,0.001092,3.566045,0.107985,1.832828,0.075842,30.663644,1.926456,25.866620,1.559371,0.893900,0.004742,0.057840,0.003830,6.153767,0.745221,1.556604,0.261466,77.804553,15.637780,32.768276,9.387405,0.756312,0.020613,0.070089,0.025196,7.329914,1.216034,3.277947,0.584223,121.117394,35.157038,80.923466,24.516881,0.571922,0.078917,0.167780,0.037051


In [29]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Valid']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy']]].mean(axis=1)

In [30]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Valid_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
26,CV543FJ4KW,15,0.5,0.00001,0.00010,5000,33.448883,0.251140,1.249648,0.016483,0.305817,0.026046,3.597918,0.110154,1.565886,0.223677,0.988602,0.000489,0.004857,0.000340,1.973884,0.058913,1.068113,0.041720,10.764159,0.769740,10.192201,0.713509,0.963828,0.002418,0.023631,0.002344,5.928667,0.803547,1.434067,0.193857,74.161113,15.864415,29.366410,8.152799,0.764000,0.019464,0.076850,0.025159,7.023453,1.272321,3.073703,0.539962,115.487024,35.378176,75.792092,22.404599,0.581253,0.087487,0.167888,0.037745,1.611766,7.181039,0.976215,6.476060,94.824068,0.672626
66,Y6AFCEKT6Y,30,0.5,0.00001,0.00001,700,5.458307,0.121644,0.714037,0.008806,0.171160,0.008871,1.563920,0.048664,0.520453,0.034455,0.994918,0.000281,0.001930,0.000159,1.146223,0.042741,0.603427,0.019461,4.288170,0.296937,3.541894,0.183880,0.984736,0.000969,0.008308,0.000635,5.919122,0.834883,1.428775,0.189305,74.256157,16.064973,29.343132,8.048708,0.763855,0.019865,0.077041,0.024735,7.005714,1.278590,3.049970,0.547455,115.972600,35.468850,75.866861,22.568123,0.579876,0.085684,0.167697,0.036880,0.930130,2.926045,0.989827,6.462418,95.114379,0.671865
34,GXQ2P2SDKK,50,0.1,0.00001,0.00010,5000,25.105668,0.745796,0.823288,0.012226,0.200387,0.008926,1.914216,0.066072,0.676675,0.045962,0.993865,0.000368,0.002199,0.000194,1.292848,0.050151,0.673672,0.023415,5.018951,0.356296,4.040549,0.224634,0.982219,0.001032,0.009274,0.000642,6.008754,0.819841,1.511815,0.242382,75.865850,16.087144,31.565586,8.896900,0.762199,0.021740,0.071195,0.025631,7.126002,1.263656,3.142690,0.586029,117.486985,35.511033,77.624461,23.970755,0.581328,0.081824,0.164059,0.038013,1.058068,3.466583,0.988042,6.567378,96.676417,0.671763
47,MWLHV0Q0ZX,20,0.2,0.00001,0.00010,700,3.757856,0.069571,0.908873,0.006665,0.223512,0.013164,2.155692,0.040295,0.807204,0.071981,0.993131,0.000390,0.002497,0.000161,1.424768,0.042072,0.747860,0.029397,5.798231,0.342509,4.840678,0.257263,0.979782,0.000970,0.011004,0.000830,6.009874,0.817622,1.515496,0.240315,75.774510,15.939857,31.446591,8.783020,0.762110,0.022045,0.071709,0.026229,7.127898,1.258928,3.138187,0.582860,117.248960,35.234074,77.388521,23.830030,0.581130,0.081794,0.164495,0.038080,1.166821,3.976961,0.986456,6.568886,96.511735,0.671620
29,FFYN7HARRS,30,0.3,0.00001,0.00001,2500,13.748870,0.224771,0.824380,0.011918,0.200722,0.008974,1.920824,0.064670,0.679534,0.046792,0.993845,0.000364,0.002208,0.000200,1.293974,0.049869,0.673959,0.023264,5.034824,0.356388,4.049318,0.221905,0.982166,0.001043,0.009288,0.000652,6.013865,0.823339,1.513602,0.243905,76.011960,16.152667,31.601735,8.871752,0.761754,0.021536,0.071403,0.025624,7.126947,1.268465,3.143961,0.585747,117.476754,35.596499,77.621708,23.944304,0.581077,0.082096,0.164437,0.038199,1.059177,3.477824,0.988006,6.570406,96.744357,0.671415
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25,CCMTJO3MD1,20,0.1,0.00100,0.00010,700,2.609989,0.033342,3.462237,0.088739,0.802044,0.023868,28.242584,1.182303,10.613197,0.638262,0.911174,0.004933,0.028747,0.002365,4.941011,0.225529,2.374581,0.105910,60.494584,4.707265,43.959677,3.267609,0.780406,0.010369,0.099339,0.005816,6.419189,0.790585,1.647694,0.299626,84.290558,17.248092,37.071986,10.758853,0.739210,0.022600,0.069660,0.023300,7.718355,1.239065,3.467659,0.641105,134.580626,38.268388,90.081146,28.869188,0.532837,0.078649,0.173953,0.035764,4.201624,44.368584,0.845790,7.068772,109.435592,0.636024
53,QF78QEF8VW,20,0.2,0.00100,0.00001,700,2.594873,0.033641,3.462237,0.088739,0.802044,0.023868,28.242584,1.182303,10.613197,0.638262,0.911174,0.004933,0.028747,0.002365,4.941011,0.225529,2.374581,0.105910,60.494584,4.707265,43.959677,3.267609,0.780406,0.010369,0.099339,0.005816,6.419189,0.790585,1.647694,0.299626,84.290558,17.248092,37.071986,10.758853,0.739210,0.022600,0.069660,0.023300,7.718355,1.239065,3.467659,0.641105,134.580626,38.268388,90.081146,28.869188,0.532837,0.078649,0.173953,0.035764,4.201624,44.368584,0.845790,7.068772,109.435592,0.636024
68,YXPTD58P28,50,0.1,0.00100,0.00001,700,2.636777,0.038429,3.462491,0.088765,0.802154,0.023530,28.254431,1.184276,10.621671,0.627734,0.911148,0.004939,0.028733,0.002370,4.941256,0.225408,2.374734,0.105301,60.506527,4.696738,43.964498,3.251492,0.780439,0.010429,0.099286,0.005788,6.420915,0.791822,1.649123,0.302834,84.331579,17.273302,37.130720,10.837968,0.739138,0.022632,0.069668,0.023398,7.718021,1.239563,3.469262,0.640747,134.648831,38.302209,90.186472,28.868913,0.532886,0.078709,0.173945,0.035661,4.201874,44.380479,0.845793,7.069468,109.490205,0.636012
60,SZHI9U6IQO,13,0.2,0.00100,0.00001,2500,9.087514,0.188091,3.612291,0.079700,0.838990,0.029327,29.612062,1.079130,11.335785,0.663465,0.907009,0.005129,0.030373,0.002301,5.131491,0.213578,2.483685,0.108129,63.456225,4.649300,46.604723,3.229108,0.771202,0.010131,0.103788,0.006303,6.423146,0.755573,1.644586,0.300609,84.147618,16.835162,36.820328,10.472859,0.738977,0.022948,0.070325,0.023505,7.736785,1.222422,3.483841,0.629312,134.575124,37.707422,90.231051,28.446553,0.532481,0.077704,0.175059,0.035765,4.371891,46.534144,0.839105,7.079966,109.361371,0.635729


In [31]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'max_depth': 15, 'max_features': 0.5, 'min_samples_leaf': 1e-05, 'min_samples_split': 0.0001, 'n_estimators': 5000}


## Best model

In [17]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_train_{}.npy'.format(DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_test_{}.npy'.format(DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_train_{}.npy'.format(DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_test_{}.npy'.format(DATA_ID)))

In [18]:
# Setup the model with the best parameters
model = RandomForestRegressor(**best_params, random_state=0, n_jobs=-1, verbose=1)

model.fit(X_train, Y_train)

# Save the model
dump(model, os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_best_model_{}_{}.joblib'.format(HS_DATE, DATA_ID))) 

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    0.6s
[Parallel(n_jobs=-1)]: Done 184 tasks      | elapsed:    2.2s
[Parallel(n_jobs=-1)]: Done 434 tasks      | elapsed:    5.2s
[Parallel(n_jobs=-1)]: Done 784 tasks      | elapsed:    9.3s
[Parallel(n_jobs=-1)]: Done 1234 tasks      | elapsed:   14.3s
[Parallel(n_jobs=-1)]: Done 1784 tasks      | elapsed:   20.5s
[Parallel(n_jobs=-1)]: Done 2434 tasks      | elapsed:   28.2s
[Parallel(n_jobs=-1)]: Done 3184 tasks      | elapsed:   37.2s
[Parallel(n_jobs=-1)]: Done 4034 tasks      | elapsed:   47.5s
[Parallel(n_jobs=-1)]: Done 4984 tasks      | elapsed:   59.5s
[Parallel(n_jobs=-1)]: Done 6034 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-1)]: Done 7184 tasks      | elapsed:  1.5min
[Parallel(n_jobs=-1)]: Done 8434 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 9784 tasks      | elapsed:  2.0min
[Parallel(n_jobs=-1)]: Done 10000 out of 

['../../../../results/0007_19072021/RF_21072021/RF_best_model_21072021_0007_19072021.joblib']

In [20]:
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

results = {
    'Train': {
        'MAE': mean_absolute_error(Y_train, train_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_train, train_preds, multioutput='raw_values'),
        'R2': r2_score(Y_train, train_preds, multioutput='raw_values')
    },
    'Test': {
        'MAE': mean_absolute_error(Y_test, test_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_test, test_preds, multioutput='raw_values'),
        'R2': r2_score(Y_test, test_preds, multioutput='raw_values')
    }       
    
}

# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results[subset][loss][i + f] for i in range(0, len(CELLS) * 2, 2)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done 184 tasks      | elapsed:    0.2s
[Parallel(n_jobs=8)]: Done 434 tasks      | elapsed:    0.4s
[Parallel(n_jobs=8)]: Done 784 tasks      | elapsed:    0.6s
[Parallel(n_jobs=8)]: Done 1234 tasks      | elapsed:    0.9s
[Parallel(n_jobs=8)]: Done 1784 tasks      | elapsed:    1.4s
[Parallel(n_jobs=8)]: Done 2434 tasks      | elapsed:    1.9s
[Parallel(n_jobs=8)]: Done 3184 tasks      | elapsed:    2.5s
[Parallel(n_jobs=8)]: Done 4034 tasks      | elapsed:    3.1s
[Parallel(n_jobs=8)]: Done 4984 tasks      | elapsed:    3.8s
[Parallel(n_jobs=8)]: Done 6034 tasks      | elapsed:    4.6s
[Parallel(n_jobs=8)]: Done 7184 tasks      | elapsed:    5.4s
[Parallel(n_jobs=8)]: Done 8434 tasks      | elapsed:    6.3s
[Parallel(n_jobs=8)]: Done 9784 tasks      | elapsed:    7.2s
[Parallel(n_jobs=8)]: Done 10000 out of 10000 | elapsed:

Train Fx MAE: 2.2904 ± 0.4356
Train Fx MSE: 13.1327 ± 3.9882
Train Fx R2: 0.9611 ± 0.0104
Train Fy MAE: 2.7660 ± 1.1741
Train Fy MSE: 19.6227 ± 14.1450
Train Fy R2: 0.9152 ± 0.0379
Test Fx MAE: 7.9542 ± 2.5406
Test Fx MSE: 164.2414 ± 107.3285
Test Fx R2: 0.6927 ± 0.0674
Test Fy MAE: 7.0967 ± 2.7796
Test Fy MSE: 117.8615 ± 70.1520
Test Fy R2: 0.5324 ± 0.1616


[Parallel(n_jobs=8)]: Done 9784 tasks      | elapsed:    2.2s
[Parallel(n_jobs=8)]: Done 10000 out of 10000 | elapsed:    2.3s finished


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()