In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor

## Configuration

In [2]:
# Number of force cells in the robotic leg
N_CELLS = 8

# Path where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the training and validation data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0006_30062021'
# Hyperparameters search date
HS_DATE = '01072021'
# Number of folds in cross-validation
CV = 6

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0006_30062021


## Hyperparameters seach analysis

In [25]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_{}_*.json'.format(HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 788


In [26]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Valid_Fz_MAE_mean__mean,Valid_Fz_MAE_mean__std,Valid_Fz_MAE_std__mean,Valid_Fz_MAE_std__std,Valid_Fz_MSE_mean__mean,Valid_Fz_MSE_mean__std,Valid_Fz_MSE_std__mean,Valid_Fz_MSE_std__std,Valid_Fz_R2_mean__mean,Valid_Fz_R2_mean__std,Valid_Fz_R2_std__mean,Valid_Fz_R2_std__std
0,J2B8WWA60A,13,0.01,0.010,0.0001,500,1.628091,0.015680,9.633375,0.116903,3.379386,0.109681,183.452179,5.226086,133.790528,7.307454,0.436990,0.009628,0.116463,0.003827,7.693463,0.059237,4.444596,0.078595,167.484605,4.281618,178.595138,7.717352,0.346292,0.004293,0.061686,0.002588,15.916403,0.207966,6.151072,0.143095,502.570743,9.729472,437.966848,12.598493,0.336618,0.007068,0.157754,0.005198,10.238879,0.515371,3.780939,0.662211,208.003110,30.967329,162.135557,52.373896,0.350398,0.035872,0.146517,0.023933,8.103008,0.243362,4.774938,0.294386,187.424325,22.163635,220.114660,40.277787,0.263947,0.021040,0.093315,0.018450,16.918744,0.957662,6.756130,0.687986,570.952406,65.823592,512.361381,101.162489,0.216110,0.042635,0.216298,0.049153
1,0151K82W8D,7,0.05,0.100,0.0001,100,0.193335,0.005371,11.063047,0.145711,3.609706,0.155726,248.088821,6.810880,166.146076,10.309018,0.233048,0.004383,0.084722,0.001813,8.854367,0.074427,5.195281,0.072967,228.735606,4.553998,248.452396,6.194865,0.129941,0.004047,0.053113,0.002708,18.056416,0.233472,7.279606,0.135973,669.688494,14.868216,660.558996,24.718440,0.160206,0.003596,0.114593,0.003298,11.248976,0.656502,3.846822,0.715427,256.300132,33.671181,181.748385,57.846630,0.191288,0.020930,0.102938,0.020506,8.958685,0.345336,5.338321,0.227219,233.289750,22.876753,273.852186,31.276344,0.093555,0.023122,0.065363,0.015023,18.332648,1.108402,7.538675,0.697567,687.178268,80.526323,679.561239,144.473501,0.102899,0.038464,0.158383,0.045133
2,03TKM9QHB2,13,0.05,0.010,0.0001,1000,3.179395,0.016767,9.631186,0.116787,3.380743,0.110199,183.330471,5.271895,133.882957,7.332289,0.437482,0.009630,0.116896,0.003874,7.691312,0.058739,4.445027,0.079654,167.244094,4.407054,178.358857,7.912230,0.347411,0.004598,0.061832,0.002450,15.914497,0.209780,6.150102,0.143707,502.414202,9.778289,437.733301,12.349140,0.336793,0.006972,0.157994,0.005161,10.235621,0.516107,3.782166,0.660774,207.822916,30.966901,162.153486,52.420460,0.351019,0.036524,0.146973,0.023832,8.100066,0.240972,4.773637,0.295441,187.272752,22.105987,219.980871,40.124878,0.264784,0.021377,0.093705,0.018867,16.913470,0.962320,6.749786,0.689486,570.837150,66.064478,512.209594,100.804959,0.216149,0.043457,0.216808,0.049031
3,04859OEBUO,15,0.20,0.010,0.0100,5000,15.801500,0.077439,9.637058,0.117997,3.381605,0.111385,183.634844,5.310259,134.070954,7.363247,0.436614,0.009549,0.116676,0.003758,7.689977,0.058839,4.443557,0.080914,167.408811,4.400836,178.618424,7.913807,0.347363,0.004746,0.061401,0.002711,15.922527,0.211654,6.155343,0.144878,503.191451,10.086524,439.178084,12.932152,0.336102,0.006898,0.157849,0.005058,10.239121,0.517463,3.781510,0.664056,208.079121,31.031983,162.335040,52.600005,0.350389,0.036270,0.146713,0.023540,8.098673,0.240839,4.772881,0.294451,187.436508,22.257681,220.237499,40.265598,0.264913,0.021081,0.093276,0.019437,16.917509,0.958306,6.753890,0.685171,571.255278,65.211777,513.290425,99.725994,0.215913,0.043027,0.216624,0.048539
4,0931K6BAVH,7,0.10,0.100,0.0010,5000,9.314562,0.072631,11.085357,0.142714,3.606290,0.154221,249.258065,6.819002,166.474732,10.242227,0.229979,0.004420,0.081540,0.002171,8.825319,0.068262,5.166870,0.071326,227.505374,4.409286,246.787375,6.079326,0.132899,0.004609,0.053994,0.002857,18.068303,0.231697,7.297246,0.134362,671.684385,14.455567,665.802072,24.052524,0.158882,0.003524,0.114950,0.003472,11.272328,0.664862,3.839669,0.708257,257.478339,33.740412,181.942306,57.474815,0.187888,0.021088,0.098909,0.019902,8.930798,0.337251,5.309995,0.229006,232.162745,22.825110,272.473943,31.525708,0.096370,0.023712,0.066744,0.014915,18.342988,1.110457,7.557588,0.698469,689.316738,81.697885,685.253148,147.484010,0.101772,0.038739,0.158249,0.043753
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
783,CDTFDCZXCF,7,0.20,0.010,0.0001,1000,2.946488,0.020424,9.740921,0.117150,3.427390,0.116934,186.892622,5.425737,136.815692,7.668133,0.426366,0.008919,0.119348,0.003705,7.841077,0.058125,4.543714,0.081560,172.467787,4.550836,183.576865,8.298512,0.327177,0.005279,0.067598,0.002716,16.114865,0.209244,6.240047,0.146637,511.961258,10.015419,444.583579,12.288027,0.323142,0.005965,0.160832,0.005442,10.272429,0.524030,3.793847,0.671589,209.194225,31.546990,162.918137,53.535267,0.346832,0.035777,0.146467,0.023375,8.191642,0.252944,4.839734,0.293717,190.371607,22.370413,222.861628,39.934912,0.252966,0.021910,0.092988,0.018098,16.974632,0.983184,6.784972,0.710488,573.663156,66.505489,515.970798,101.582179,0.214066,0.041510,0.212920,0.047068
784,CE40FHSD6T,5,0.05,0.010,0.0100,100,0.276744,0.011434,10.034749,0.122733,3.562851,0.135627,198.578809,5.886994,147.440752,8.814619,0.392124,0.007542,0.122652,0.002645,8.105753,0.058765,4.702596,0.083562,184.910377,4.952511,197.598634,9.150238,0.282316,0.006355,0.073763,0.002853,16.602499,0.214493,6.491210,0.154691,541.404835,12.279445,475.906822,17.770542,0.286856,0.004254,0.162117,0.005178,10.426914,0.565142,3.866368,0.708549,215.772083,33.570160,169.049597,57.102885,0.328194,0.036530,0.145808,0.022483,8.354451,0.280391,4.936336,0.279233,197.383192,22.600818,231.406855,38.797672,0.225156,0.022031,0.092776,0.014692,17.233920,1.038497,6.940977,0.708269,588.949754,67.055431,535.306424,106.071060,0.198486,0.039280,0.206055,0.045332
785,CGOQXQSH9V,5,0.05,0.100,0.0001,500,0.951243,0.005840,11.063924,0.145091,3.601860,0.155607,248.259811,6.941509,165.553413,10.343375,0.232339,0.004739,0.083405,0.001956,8.852474,0.070918,5.189707,0.071375,228.707996,4.443764,248.219156,6.049831,0.129188,0.004400,0.052953,0.003096,18.061306,0.235907,7.288126,0.134095,670.885853,14.381263,663.099472,23.238742,0.159463,0.003748,0.114291,0.003846,11.253594,0.662460,3.836185,0.714306,256.604568,33.770878,181.071884,57.724196,0.189972,0.020838,0.101053,0.019915,8.958647,0.346673,5.334121,0.229227,233.457117,23.031321,273.823727,31.717192,0.092374,0.023243,0.064948,0.015188,18.341113,1.108534,7.550649,0.698466,688.870020,81.089224,682.980316,145.487751,0.101843,0.038427,0.157647,0.044714
786,CH5RU53N41,10,0.01,0.001,0.0100,500,3.535203,0.271417,8.956884,0.098409,3.139601,0.097935,155.900792,4.233398,114.746735,5.953734,0.519229,0.009561,0.117905,0.003860,7.183106,0.061657,4.068786,0.091295,137.943190,4.057779,142.538660,7.801376,0.432150,0.005511,0.075540,0.003320,14.674314,0.187562,5.495889,0.134132,412.247926,7.482473,319.560421,7.269180,0.432161,0.006886,0.163999,0.006899,9.976846,0.439425,3.732942,0.636045,196.157418,28.178674,156.270947,48.429771,0.385618,0.045630,0.159794,0.024445,7.875526,0.259355,4.616516,0.336529,174.027777,24.178227,202.765023,48.406552,0.305385,0.020803,0.105317,0.022462,16.388148,0.816661,6.407908,0.603444,525.083271,50.776747,439.203373,73.604825,0.253461,0.046762,0.236488,0.054618


In [27]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Valid']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy', 'Fz']]].mean(axis=1)

In [28]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Valid_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Valid_Fz_MAE_mean__mean,Valid_Fz_MAE_mean__std,Valid_Fz_MAE_std__mean,Valid_Fz_MAE_std__std,Valid_Fz_MSE_mean__mean,Valid_Fz_MSE_mean__std,Valid_Fz_MSE_std__mean,Valid_Fz_MSE_std__std,Valid_Fz_R2_mean__mean,Valid_Fz_R2_mean__std,Valid_Fz_R2_std__mean,Valid_Fz_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
111,NZO5TE3B4L,13,0.05,0.001,0.0001,10000,43.780256,0.383255,7.677225,0.105979,2.700615,0.088058,114.687254,3.408697,85.776329,4.228359,0.645787,0.009466,0.096529,0.002411,6.079501,0.062695,3.343031,0.090165,95.155414,3.075837,95.230762,5.946895,0.585452,0.005064,0.071672,0.002682,12.452204,0.200666,4.547755,0.132292,291.771144,6.514146,201.701731,5.292893,0.581715,0.008151,0.140559,0.006634,9.795947,0.385915,3.741132,0.593070,190.400700,25.896013,156.721497,43.107825,0.404354,0.058445,0.171829,0.029027,7.593492,0.276075,4.407584,0.397541,162.533029,26.810420,190.174401,56.513626,0.344356,0.021445,0.116253,0.027182,16.005167,0.663544,6.181631,0.487822,495.155921,38.586024,388.197268,56.114377,0.271685,0.054570,0.262390,0.067865,8.736310,167.204604,0.604318,11.131535,282.696550,0.340132
669,IJL38Z9NAQ,13,0.01,0.001,0.0010,10000,43.949214,0.291127,7.677225,0.105979,2.700615,0.088058,114.687254,3.408697,85.776329,4.228359,0.645787,0.009466,0.096529,0.002411,6.079501,0.062695,3.343031,0.090165,95.155414,3.075837,95.230762,5.946895,0.585452,0.005064,0.071672,0.002682,12.452204,0.200666,4.547755,0.132292,291.771144,6.514146,201.701731,5.292893,0.581715,0.008151,0.140559,0.006634,9.795947,0.385915,3.741132,0.593070,190.400700,25.896013,156.721497,43.107825,0.404354,0.058445,0.171829,0.029027,7.593492,0.276075,4.407584,0.397541,162.533029,26.810420,190.174401,56.513626,0.344356,0.021445,0.116253,0.027182,16.005167,0.663544,6.181631,0.487822,495.155921,38.586024,388.197268,56.114377,0.271685,0.054570,0.262390,0.067865,8.736310,167.204604,0.604318,11.131535,282.696550,0.340132
228,L6BHB49WUS,13,0.05,0.001,0.0010,10000,44.035617,0.363289,7.677225,0.105979,2.700615,0.088058,114.687254,3.408697,85.776329,4.228359,0.645787,0.009466,0.096529,0.002411,6.079501,0.062695,3.343031,0.090165,95.155414,3.075837,95.230762,5.946895,0.585452,0.005064,0.071672,0.002682,12.452204,0.200666,4.547755,0.132292,291.771144,6.514146,201.701731,5.292893,0.581715,0.008151,0.140559,0.006634,9.795947,0.385915,3.741132,0.593070,190.400700,25.896013,156.721497,43.107825,0.404354,0.058445,0.171829,0.029027,7.593492,0.276075,4.407584,0.397541,162.533029,26.810420,190.174401,56.513626,0.344356,0.021445,0.116253,0.027182,16.005167,0.663544,6.181631,0.487822,495.155921,38.586024,388.197268,56.114377,0.271685,0.054570,0.262390,0.067865,8.736310,167.204604,0.604318,11.131535,282.696550,0.340132
520,Q9AZOD9FDJ,13,0.20,0.001,0.0010,10000,44.161971,0.626995,7.677225,0.105979,2.700615,0.088058,114.687254,3.408697,85.776329,4.228359,0.645787,0.009466,0.096529,0.002411,6.079501,0.062695,3.343031,0.090165,95.155414,3.075837,95.230762,5.946895,0.585452,0.005064,0.071672,0.002682,12.452204,0.200666,4.547755,0.132292,291.771144,6.514146,201.701731,5.292893,0.581715,0.008151,0.140559,0.006634,9.795947,0.385915,3.741132,0.593070,190.400700,25.896013,156.721497,43.107825,0.404354,0.058445,0.171829,0.029027,7.593492,0.276075,4.407584,0.397541,162.533029,26.810420,190.174401,56.513626,0.344356,0.021445,0.116253,0.027182,16.005167,0.663544,6.181631,0.487822,495.155921,38.586024,388.197268,56.114377,0.271685,0.054570,0.262390,0.067865,8.736310,167.204604,0.604318,11.131535,282.696550,0.340132
283,EV5823049O,13,0.10,0.001,0.0001,10000,44.313295,0.347843,7.677225,0.105979,2.700615,0.088058,114.687254,3.408697,85.776329,4.228359,0.645787,0.009466,0.096529,0.002411,6.079501,0.062695,3.343031,0.090165,95.155414,3.075837,95.230762,5.946895,0.585452,0.005064,0.071672,0.002682,12.452204,0.200666,4.547755,0.132292,291.771144,6.514146,201.701731,5.292893,0.581715,0.008151,0.140559,0.006634,9.795947,0.385915,3.741132,0.593070,190.400700,25.896013,156.721497,43.107825,0.404354,0.058445,0.171829,0.029027,7.593492,0.276075,4.407584,0.397541,162.533029,26.810420,190.174401,56.513626,0.344356,0.021445,0.116253,0.027182,16.005167,0.663544,6.181631,0.487822,495.155921,38.586024,388.197268,56.114377,0.271685,0.054570,0.262390,0.067865,8.736310,167.204604,0.604318,11.131535,282.696550,0.340132
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
265,R8VZG2XPJL,5,0.10,0.100,0.0100,500,0.976789,0.014837,11.063924,0.145091,3.601860,0.155607,248.259811,6.941509,165.553413,10.343375,0.232339,0.004739,0.083405,0.001956,8.852474,0.070918,5.189707,0.071375,228.707996,4.443764,248.219156,6.049831,0.129188,0.004400,0.052953,0.003096,18.061306,0.235907,7.288126,0.134095,670.885853,14.381263,663.099472,23.238742,0.159463,0.003748,0.114291,0.003846,11.253594,0.662460,3.836185,0.714306,256.604568,33.770878,181.071884,57.724196,0.189972,0.020838,0.101053,0.019915,8.958647,0.346673,5.334121,0.229227,233.457117,23.031321,273.823727,31.717192,0.092374,0.023243,0.064948,0.015188,18.341113,1.108534,7.550649,0.698466,688.870020,81.089224,682.980316,145.487751,0.101843,0.038427,0.157647,0.044714,12.659235,382.617887,0.173663,12.851118,392.977235,0.128063
471,XGRRD64A7B,10,0.05,0.100,0.0100,500,0.940387,0.016989,11.063924,0.145091,3.601860,0.155607,248.259811,6.941509,165.553413,10.343375,0.232339,0.004739,0.083405,0.001956,8.852474,0.070918,5.189707,0.071375,228.707996,4.443764,248.219156,6.049831,0.129188,0.004400,0.052953,0.003096,18.061306,0.235907,7.288126,0.134095,670.885853,14.381263,663.099472,23.238742,0.159463,0.003748,0.114291,0.003846,11.253594,0.662460,3.836185,0.714306,256.604568,33.770878,181.071884,57.724196,0.189972,0.020838,0.101053,0.019915,8.958647,0.346673,5.334121,0.229227,233.457117,23.031321,273.823727,31.717192,0.092374,0.023243,0.064948,0.015188,18.341113,1.108534,7.550649,0.698466,688.870020,81.089224,682.980316,145.487751,0.101843,0.038427,0.157647,0.044714,12.659235,382.617887,0.173663,12.851118,392.977235,0.128063
91,BOPTQ41V9G,7,0.01,0.100,0.0010,500,0.961275,0.011368,11.063924,0.145091,3.601860,0.155607,248.259811,6.941509,165.553413,10.343375,0.232339,0.004739,0.083405,0.001956,8.852474,0.070918,5.189707,0.071375,228.707996,4.443764,248.219156,6.049831,0.129188,0.004400,0.052953,0.003096,18.061306,0.235907,7.288126,0.134095,670.885853,14.381263,663.099472,23.238742,0.159463,0.003748,0.114291,0.003846,11.253594,0.662460,3.836185,0.714306,256.604568,33.770878,181.071884,57.724196,0.189972,0.020838,0.101053,0.019915,8.958647,0.346673,5.334121,0.229227,233.457117,23.031321,273.823727,31.717192,0.092374,0.023243,0.064948,0.015188,18.341113,1.108534,7.550649,0.698466,688.870020,81.089224,682.980316,145.487751,0.101843,0.038427,0.157647,0.044714,12.659235,382.617887,0.173663,12.851118,392.977235,0.128063
259,R2RWCBA9MM,5,0.20,0.100,0.0100,500,0.992487,0.021598,11.063924,0.145091,3.601860,0.155607,248.259811,6.941509,165.553413,10.343375,0.232339,0.004739,0.083405,0.001956,8.852474,0.070918,5.189707,0.071375,228.707996,4.443764,248.219156,6.049831,0.129188,0.004400,0.052953,0.003096,18.061306,0.235907,7.288126,0.134095,670.885853,14.381263,663.099472,23.238742,0.159463,0.003748,0.114291,0.003846,11.253594,0.662460,3.836185,0.714306,256.604568,33.770878,181.071884,57.724196,0.189972,0.020838,0.101053,0.019915,8.958647,0.346673,5.334121,0.229227,233.457117,23.031321,273.823727,31.717192,0.092374,0.023243,0.064948,0.015188,18.341113,1.108534,7.550649,0.698466,688.870020,81.089224,682.980316,145.487751,0.101843,0.038427,0.157647,0.044714,12.659235,382.617887,0.173663,12.851118,392.977235,0.128063


In [29]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'max_depth': 13, 'max_features': 0.05, 'min_samples_leaf': 0.001, 'min_samples_split': 0.0001, 'n_estimators': 10000}


## Best model

In [8]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_train_{}.npy'.format(DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_test_{}.npy'.format(DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_train_{}.npy'.format(DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_test_{}.npy'.format(DATA_ID)))

In [10]:
# Setup the model with the best parameters
model = RandomForestRegressor(**best_params, random_state=0, n_jobs=-1, verbose=1)

model.fit(X_train, Y_train)

# Save the model
dump(model, os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_best_model_{}_{}.joblib'.format(HS_DATE, DATA_ID))) 

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    1.9s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:    5.0s finished


['../../../results/0003_11042021/RF_12042021/RF_best_model_12042021_0003_11042021.joblib']

In [11]:
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

results = {
    'Train': {
        'MAE': mean_absolute_error(Y_train, train_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_train, train_preds, multioutput='raw_values'),
        'R2': r2_score(Y_train, train_preds, multioutput='raw_values')
    },
    'Test': {
        'MAE': mean_absolute_error(Y_test, test_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_test, test_preds, multioutput='raw_values'),
        'R2': r2_score(Y_test, test_preds, multioutput='raw_values')
    }       
    
}

# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy', 'Fz']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.2s
[Parallel(n_jobs=8)]: Done 100 out of 100 | elapsed:    0.5s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Done 100 out of 100 | elapsed:    0.1s finished


Train Fx MAE: 8.6461 ± 2.4048
Train Fx MSE: 151.6214 ± 77.8937
Train Fx R2: 0.5655 ± 0.0612
Train Fy MAE: 8.7525 ± 6.0748
Train Fy MSE: 238.7282 ± 265.7849
Train Fy R2: 0.4389 ± 0.1117
Train Fz MAE: 11.2891 ± 3.9985
Train Fz MSE: 300.2048 ± 246.7115
Train Fz R2: 0.5458 ± 0.0617
Test Fx MAE: 13.0032 ± 5.0952
Test Fx MSE: 311.5999 ± 249.7830
Test Fx R2: 0.3473 ± 0.2863
Test Fy MAE: 10.7406 ± 7.7562
Test Fy MSE: 378.7206 ± 481.6318
Test Fy R2: 0.3404 ± 0.1027
Test Fz MAE: 19.0024 ± 6.5255
Test Fz MSE: 687.2684 ± 593.5296
Test Fz R2: 0.3611 ± 0.2479


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()