In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor

## Configuration

In [2]:
# Number of force cells in the robotic leg
N_CELLS = 8

# Path where the results are stored
RESULTS_PATH = '../../../results'
# ID of the training and test data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0003_11042021'
# Hyperparameters search date
HS_DATE = '12042021'
# Number of folds in cross-validation
CV = 6

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0003_11042021


## Hyperparameters seach analysis

In [3]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_{}_*.json'.format(HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 189


In [4]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Test_Fx_MAE_mean__mean,Test_Fx_MAE_mean__std,Test_Fx_MAE_std__mean,Test_Fx_MAE_std__std,Test_Fx_MSE_mean__mean,Test_Fx_MSE_mean__std,Test_Fx_MSE_std__mean,Test_Fx_MSE_std__std,Test_Fx_R2_mean__mean,Test_Fx_R2_mean__std,Test_Fx_R2_std__mean,Test_Fx_R2_std__std,Test_Fy_MAE_mean__mean,Test_Fy_MAE_mean__std,Test_Fy_MAE_std__mean,Test_Fy_MAE_std__std,Test_Fy_MSE_mean__mean,Test_Fy_MSE_mean__std,Test_Fy_MSE_std__mean,Test_Fy_MSE_std__std,Test_Fy_R2_mean__mean,Test_Fy_R2_mean__std,Test_Fy_R2_std__mean,Test_Fy_R2_std__std,Test_Fz_MAE_mean__mean,Test_Fz_MAE_mean__std,Test_Fz_MAE_std__mean,Test_Fz_MAE_std__std,Test_Fz_MSE_mean__mean,Test_Fz_MSE_mean__std,Test_Fz_MSE_std__mean,Test_Fz_MSE_std__std,Test_Fz_R2_mean__mean,Test_Fz_R2_mean__std,Test_Fz_R2_std__mean,Test_Fz_R2_std__std
0,36LZKCUXPX,13,0.2,0.010,0.0100,1000,38.531896,0.292728,8.293597,0.269200,2.320279,0.092058,139.426954,8.553446,70.675360,3.545023,0.597713,0.021375,0.060194,0.010535,8.491057,0.206025,5.914225,0.253472,223.625125,11.640757,250.455256,17.344791,0.468452,0.015957,0.111703,0.009999,10.782770,0.405617,3.841568,0.246229,272.251313,21.912704,220.031718,31.247252,0.581321,0.022405,0.068758,0.014736,12.248514,1.468944,4.288143,0.729241,299.741081,73.875420,199.434367,57.616807,-0.002072,0.295405,0.252135,0.060772,11.281859,1.817362,8.714458,1.739269,429.737166,136.841507,641.317243,180.387693,-0.063488,0.291136,0.413969,0.493538,16.220895,2.878738,6.707404,1.552511,599.830546,225.920637,574.990043,338.018971,-0.280532,0.311635,0.725191,0.439405
1,WF84XYZ7X6,15,0.2,0.001,0.0001,100,5.584132,0.055184,3.989016,0.216867,1.032845,0.068532,40.075266,4.704387,17.793164,2.011311,0.881916,0.012370,0.027325,0.005003,4.386793,0.090987,2.961924,0.114399,63.063597,2.735547,67.510873,4.402733,0.824512,0.010420,0.053039,0.004350,4.974872,0.216206,1.797059,0.134057,66.075079,6.296866,42.509821,6.346733,0.891425,0.008484,0.029403,0.003634,11.820579,1.762237,4.023486,0.974601,292.937572,90.624492,189.939796,80.309630,0.004759,0.339559,0.280261,0.113298,11.540467,1.832232,8.883276,1.977760,452.815338,149.954984,642.917186,188.426411,-0.212758,0.434563,0.660147,0.908696,15.699374,3.063488,6.706369,1.735746,590.315274,229.493714,568.157855,305.878718,-0.280355,0.409758,0.802198,0.633418
2,SDUEIILGAK,13,0.3,0.100,0.0100,5000,90.850062,3.026931,11.857662,0.221253,3.881044,0.143315,276.128594,8.270421,170.340212,5.352094,0.233624,0.028285,0.073646,0.012372,10.874762,0.290503,7.655420,0.351103,393.822840,18.403318,483.010909,32.685227,0.158749,0.022415,0.070464,0.012522,15.648575,0.519284,5.486778,0.333128,545.148437,43.343965,479.314836,77.110240,0.177893,0.027113,0.092930,0.010166,13.451048,1.117708,5.499440,0.852936,349.730815,56.666686,263.210643,63.415569,-0.102518,0.123107,0.285219,0.123479,11.502202,1.941588,8.898177,2.165522,452.338312,151.367316,728.910645,249.040555,-0.017960,0.132479,0.165428,0.097025,17.696312,2.791643,6.878257,1.283146,665.362901,263.778701,621.664918,439.415503,-0.436281,0.203747,0.780843,0.314696
3,OKC90NRKB5,7,0.3,0.100,0.0010,100,1.903570,0.064207,11.857920,0.225918,3.878876,0.145550,276.614825,8.029381,170.638422,5.185638,0.232178,0.028484,0.074727,0.012475,10.868781,0.299385,7.653613,0.367680,392.625320,19.024922,480.543125,33.574701,0.160403,0.021550,0.070124,0.012387,15.653111,0.527681,5.485920,0.327217,545.683071,43.542934,479.455911,76.349349,0.176916,0.029040,0.094153,0.010166,13.480208,1.187067,5.523776,0.880125,351.893082,60.079935,265.869189,67.540290,-0.106440,0.121928,0.284136,0.122868,11.505940,1.911590,8.887789,2.119950,452.207070,149.724227,726.686467,245.424040,-0.019566,0.133892,0.170532,0.108835,17.753505,2.859212,6.907565,1.270543,670.425852,270.763023,627.146726,446.394509,-0.445931,0.204270,0.790748,0.325540
4,1RRSAQUBZM,10,0.4,0.001,0.0010,10000,862.378149,6.169347,4.311347,0.193452,1.225088,0.072743,45.748409,4.701752,23.471781,2.282427,0.868400,0.011994,0.028348,0.005595,4.838870,0.099091,3.293868,0.135245,73.930657,3.515774,80.724645,5.962941,0.791701,0.007035,0.072507,0.004027,5.373902,0.158938,1.938601,0.128935,73.648607,5.514519,43.066574,5.421628,0.875566,0.008451,0.036786,0.003293,12.080452,1.911269,4.150511,1.127462,311.230412,102.321082,202.989136,95.541040,-0.059311,0.369267,0.326379,0.162195,11.668374,1.922511,9.076120,2.113338,475.970878,161.363347,683.552114,206.761362,-0.295472,0.558230,0.820828,1.216810,16.050455,3.113654,6.823079,1.941123,623.565193,236.844045,586.399848,297.535242,-0.385923,0.455062,0.920557,0.803510
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
184,WCVNOPSTE2,15,0.3,0.100,0.0010,5000,91.671640,3.358676,11.857662,0.221253,3.881044,0.143315,276.128594,8.270421,170.340212,5.352094,0.233624,0.028285,0.073646,0.012372,10.874762,0.290503,7.655420,0.351103,393.822840,18.403318,483.010909,32.685227,0.158749,0.022415,0.070464,0.012522,15.648575,0.519284,5.486778,0.333128,545.148437,43.343965,479.314836,77.110240,0.177893,0.027113,0.092930,0.010166,13.451048,1.117708,5.499440,0.852936,349.730815,56.666686,263.210643,63.415569,-0.102518,0.123107,0.285219,0.123479,11.502202,1.941588,8.898177,2.165522,452.338312,151.367316,728.910645,249.040555,-0.017960,0.132479,0.165428,0.097025,17.696312,2.791643,6.878257,1.283146,665.362901,263.778701,621.664918,439.415503,-0.436281,0.203747,0.780843,0.314696
185,ZCMMKVZY6K,7,0.3,0.100,0.0001,1000,18.656924,0.527760,11.859205,0.225506,3.889958,0.147542,276.302545,8.327470,170.830848,5.459960,0.233551,0.028380,0.073732,0.012421,10.871135,0.287677,7.654423,0.349526,393.406674,18.033768,482.339112,32.037101,0.159579,0.022156,0.070852,0.012295,15.648843,0.522107,5.483349,0.331850,545.035133,43.208157,478.873006,76.722992,0.177857,0.026942,0.093208,0.010053,13.449549,1.129481,5.505629,0.864459,349.850757,57.065213,263.846245,64.195166,-0.102315,0.122294,0.286201,0.124702,11.497578,1.938110,8.895531,2.161017,451.899712,150.978404,728.031807,248.212562,-0.017254,0.132858,0.166687,0.098508,17.692956,2.810656,6.879572,1.286399,665.599490,265.198396,622.203032,440.845645,-0.435512,0.203364,0.780792,0.316006
186,NLWKP59QTJ,7,0.4,0.100,0.0001,1000,24.944751,0.941889,11.729506,0.249569,3.795355,0.109389,271.626962,8.908975,165.045819,3.745452,0.243209,0.029967,0.076738,0.014941,10.784887,0.290793,7.550610,0.345789,382.622298,17.124490,461.345172,27.049530,0.166865,0.022041,0.077251,0.012441,15.583412,0.498891,5.526707,0.319760,542.580561,43.193193,479.465497,75.881634,0.185285,0.029451,0.092798,0.011786,13.504663,1.224099,5.495761,0.851675,353.142780,63.929066,264.281676,68.450513,-0.116794,0.141568,0.286167,0.127650,11.478115,1.940891,8.878837,1.952403,446.491285,144.425041,710.138837,224.001405,-0.022657,0.159297,0.211867,0.165853,17.859409,2.989189,7.009970,1.249628,678.793308,278.409161,637.298358,445.512447,-0.459729,0.222023,0.816078,0.361470
187,97MBEF35B7,15,0.3,0.010,0.0100,5000,238.648859,2.904252,7.930370,0.274705,2.184640,0.096536,129.553844,8.518714,64.503778,3.661191,0.624659,0.020501,0.061508,0.011288,8.271924,0.205529,5.737839,0.243633,211.220314,11.501014,233.833186,16.705599,0.488044,0.015060,0.116492,0.009504,10.315885,0.405746,3.686008,0.238592,253.275951,21.520286,201.838517,29.508107,0.608821,0.022298,0.068511,0.014638,12.138196,1.624612,4.206602,0.797400,298.827259,81.507177,198.096794,64.725576,-0.006122,0.322417,0.264639,0.081113,11.281710,1.835123,8.718260,1.688286,429.941508,139.623497,633.170657,170.045278,-0.095830,0.357519,0.505341,0.676751,16.160469,2.998770,6.729240,1.639488,602.424810,228.457876,573.064014,322.731863,-0.300819,0.349261,0.772221,0.515670


In [5]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Test']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy', 'Fz']]].mean(axis=1)

In [6]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Test_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Test_Fx_MAE_mean__mean,Test_Fx_MAE_mean__std,Test_Fx_MAE_std__mean,Test_Fx_MAE_std__std,Test_Fx_MSE_mean__mean,Test_Fx_MSE_mean__std,Test_Fx_MSE_std__mean,Test_Fx_MSE_std__std,Test_Fx_R2_mean__mean,Test_Fx_R2_mean__std,Test_Fx_R2_std__mean,Test_Fx_R2_std__std,Test_Fy_MAE_mean__mean,Test_Fy_MAE_mean__std,Test_Fy_MAE_std__mean,Test_Fy_MAE_std__std,Test_Fy_MSE_mean__mean,Test_Fy_MSE_mean__std,Test_Fy_MSE_std__mean,Test_Fy_MSE_std__std,Test_Fy_R2_mean__mean,Test_Fy_R2_mean__std,Test_Fy_R2_std__mean,Test_Fy_R2_std__std,Test_Fz_MAE_mean__mean,Test_Fz_MAE_mean__std,Test_Fz_MAE_std__mean,Test_Fz_MAE_std__std,Test_Fz_MSE_mean__mean,Test_Fz_MSE_mean__std,Test_Fz_MSE_std__mean,Test_Fz_MSE_std__std,Test_Fz_R2_mean__mean,Test_Fz_R2_mean__std,Test_Fz_R2_std__mean,Test_Fz_R2_std__std,Train_MAE,Train_MSE,Train_R2,Test_MAE,Test_MSE,Test_R2
159,1KGYR9IE8V,15,0.2,0.010,0.0100,100,3.907211,0.041695,8.250258,0.286284,2.293878,0.101546,138.328332,8.990843,69.711930,3.716997,0.600403,0.021955,0.060052,0.009706,8.462369,0.190287,5.884515,0.241281,222.166744,10.742809,248.041411,15.629646,0.470333,0.017016,0.112322,0.008769,10.739287,0.399360,3.824509,0.250649,271.001447,21.902364,218.370869,31.965609,0.583117,0.022469,0.069386,0.013905,12.212270,1.504909,4.245837,0.704140,298.890833,75.587088,197.595075,57.460764,-0.001022,0.307129,0.248707,0.064487,11.251960,1.819990,8.660540,1.690955,423.984063,136.331501,629.658575,177.589440,-0.054230,0.286978,0.401921,0.460208,16.232581,2.895135,6.712613,1.565798,599.850009,225.292725,575.037872,334.109376,-0.279363,0.317695,0.720009,0.428534,9.150638,210.498841,0.551284,13.232270,440.908302,-0.111538
109,T7S2Z8YC4Y,15,0.2,0.010,0.0010,100,3.974562,0.077967,8.250258,0.286284,2.293878,0.101546,138.328332,8.990843,69.711930,3.716997,0.600403,0.021955,0.060052,0.009706,8.462369,0.190287,5.884515,0.241281,222.166744,10.742809,248.041411,15.629646,0.470333,0.017016,0.112322,0.008769,10.739287,0.399360,3.824509,0.250649,271.001447,21.902364,218.370869,31.965609,0.583117,0.022469,0.069386,0.013905,12.212270,1.504909,4.245837,0.704140,298.890833,75.587088,197.595075,57.460764,-0.001022,0.307129,0.248707,0.064487,11.251960,1.819990,8.660540,1.690955,423.984063,136.331501,629.658575,177.589440,-0.054230,0.286978,0.401921,0.460208,16.232581,2.895135,6.712613,1.565798,599.850009,225.292725,575.037872,334.109376,-0.279363,0.317695,0.720009,0.428534,9.150638,210.498841,0.551284,13.232270,440.908302,-0.111538
25,04LYEPEJY2,13,0.2,0.010,0.0001,100,4.002557,0.091523,8.251228,0.285054,2.294159,0.100303,138.361733,8.950949,69.729092,3.648570,0.600288,0.021983,0.060050,0.009765,8.464398,0.189494,5.885984,0.240626,222.247240,10.734424,248.126032,15.611168,0.470232,0.016849,0.112188,0.008699,10.741285,0.397947,3.824706,0.250248,271.018307,21.723134,218.306379,31.719268,0.582983,0.022402,0.069468,0.013964,12.212815,1.513007,4.247029,0.707139,299.001789,75.823487,197.816911,57.605714,-0.001295,0.307930,0.249053,0.064459,11.253036,1.819916,8.660282,1.687628,424.057786,136.557716,629.576392,177.914503,-0.054332,0.287102,0.402061,0.460100,16.235623,2.899161,6.714920,1.563799,599.955711,225.375526,575.217182,333.653905,-0.279699,0.318605,0.720300,0.428427,9.152304,210.542427,0.551168,13.233824,441.005095,-0.111775
170,VBPOOPFI9E,13,0.2,0.010,0.0010,100,3.894319,0.038082,8.251228,0.285054,2.294159,0.100303,138.361733,8.950949,69.729092,3.648570,0.600288,0.021983,0.060050,0.009765,8.464398,0.189494,5.885984,0.240626,222.247240,10.734424,248.126032,15.611168,0.470232,0.016849,0.112188,0.008699,10.741285,0.397947,3.824706,0.250248,271.018307,21.723134,218.306379,31.719268,0.582983,0.022402,0.069468,0.013964,12.212815,1.513007,4.247029,0.707139,299.001789,75.823487,197.816911,57.605714,-0.001295,0.307930,0.249053,0.064459,11.253036,1.819916,8.660282,1.687628,424.057786,136.557716,629.576392,177.914503,-0.054332,0.287102,0.402061,0.460100,16.235623,2.899161,6.714920,1.563799,599.955711,225.375526,575.217182,333.653905,-0.279699,0.318605,0.720300,0.428427,9.152304,210.542427,0.551168,13.233824,441.005095,-0.111775
57,3AQKIRU4A0,13,0.2,0.010,0.0100,5000,192.408436,2.175947,8.289856,0.265004,2.319593,0.092685,139.252614,8.501136,70.637471,3.489645,0.598213,0.021442,0.060118,0.010378,8.494452,0.208733,5.919307,0.256241,223.775113,11.773403,250.878836,17.690750,0.468714,0.015390,0.111220,0.009977,10.779586,0.395190,3.835227,0.245073,271.689865,21.485872,218.848128,30.816975,0.581634,0.021981,0.069050,0.014424,12.245466,1.455029,4.298207,0.722572,299.473447,73.567691,199.628404,57.437439,-0.000613,0.294193,0.252370,0.059845,11.272031,1.809388,8.701695,1.745022,428.804102,136.144480,639.965292,179.961009,-0.060501,0.287464,0.408040,0.485301,16.211818,2.849782,6.703062,1.546566,598.866920,224.435392,573.984172,338.328978,-0.277862,0.308440,0.719454,0.429978,9.187965,211.572531,0.549520,13.243105,442.381490,-0.112992
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
82,PK2VTNQBKX,13,0.4,0.001,0.0001,500,46.678884,0.401129,3.516582,0.183804,0.912946,0.058096,34.231741,4.026586,15.202195,1.686653,0.899065,0.010729,0.023946,0.004332,3.986557,0.092097,2.655801,0.092787,54.832996,2.493795,58.225587,3.610438,0.842556,0.008426,0.050160,0.003498,4.353049,0.175771,1.581732,0.114109,55.984930,5.165746,35.137570,5.131413,0.907246,0.006922,0.026795,0.002864,12.045368,1.918981,4.106219,1.138385,310.496704,102.489664,200.360276,95.678836,-0.060809,0.370158,0.328555,0.160978,11.829419,1.951243,9.204999,2.192968,486.174030,164.479804,692.713984,211.006946,-0.325027,0.565120,0.841406,1.243149,16.067000,3.063690,6.809571,1.957659,624.459475,233.147318,586.477684,293.294842,-0.391491,0.454744,0.930473,0.800992,3.952063,48.349889,0.882955,13.313929,473.710070,-0.259109
78,DH6NMTT4QD,15,0.4,0.001,0.0001,500,47.428289,0.455100,3.365680,0.185680,0.848701,0.054971,32.493325,3.911436,13.939584,1.595703,0.903556,0.010582,0.023517,0.004239,3.803800,0.095862,2.520104,0.085418,51.477933,2.408428,54.377278,3.257447,0.852163,0.008899,0.045644,0.003963,4.169051,0.180632,1.508720,0.112409,53.408409,5.026840,34.115075,5.008033,0.911803,0.006689,0.025141,0.002794,12.029611,1.921999,4.100400,1.141122,309.863620,102.187249,199.386790,95.756341,-0.058747,0.368015,0.326959,0.161565,11.880577,1.970187,9.266832,2.242121,489.352679,166.474070,697.283127,215.331509,-0.334414,0.573837,0.849987,1.263437,16.077592,3.073498,6.820155,1.941787,624.841761,234.011873,585.981420,294.252902,-0.392921,0.457000,0.931540,0.810339,3.779510,45.793222,0.889174,13.329260,474.686020,-0.262027
134,FQQANJZDJG,13,0.4,0.001,0.0001,10000,930.732883,6.949876,3.514335,0.187492,0.912046,0.059090,34.211284,4.072645,15.192330,1.704906,0.899123,0.010857,0.023922,0.004350,3.983875,0.092301,2.653616,0.092405,54.796243,2.470503,58.192017,3.577782,0.842716,0.008607,0.050043,0.003594,4.349459,0.179234,1.579067,0.114907,55.938509,5.209664,35.096497,5.156241,0.907298,0.006958,0.026802,0.002860,12.058678,1.925365,4.116623,1.142724,311.451157,102.955557,201.283657,96.863781,-0.064063,0.371050,0.331912,0.166533,11.847830,1.961243,9.230890,2.210143,487.906292,165.334187,696.109359,212.988969,-0.330233,0.574022,0.846870,1.262973,16.079669,3.070083,6.815911,1.937637,625.098772,233.817421,585.714571,293.675226,-0.394465,0.457622,0.932132,0.811695,3.949223,48.315346,0.883045,13.328726,474.818740,-0.262920
176,78ONUWIHBJ,15,0.4,0.001,0.0010,5000,476.387517,4.343910,3.365635,0.183445,0.848346,0.054062,32.480457,3.904921,13.925698,1.584404,0.903581,0.010574,0.023474,0.004277,3.804636,0.096678,2.521193,0.087501,51.516757,2.430456,54.461191,3.343455,0.852161,0.008880,0.045617,0.003980,4.168858,0.180686,1.507586,0.112436,53.372644,5.064257,34.035666,5.087669,0.911800,0.006669,0.025202,0.002738,12.037501,1.933229,4.099544,1.146580,310.721046,103.022720,200.128966,96.990812,-0.062926,0.372191,0.331837,0.167255,11.880261,1.969904,9.262287,2.233737,489.803371,166.081486,698.052715,214.230656,-0.335839,0.575157,0.850798,1.268545,16.082885,3.071456,6.817088,1.934146,624.965057,233.707855,585.362183,292.980456,-0.394259,0.459149,0.933119,0.811796,3.779710,45.789953,0.889181,13.333549,475.163158,-0.264341


In [7]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'max_depth': 15, 'max_features': 0.2, 'min_samples_leaf': 0.01, 'min_samples_split': 0.01, 'n_estimators': 100}


## Best model

In [8]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_train_{}.npy'.format(DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_test_{}.npy'.format(DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_train_{}.npy'.format(DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_test_{}.npy'.format(DATA_ID)))

In [10]:
# Setup the model with the best parameters
model = RandomForestRegressor(**best_params, random_state=0, n_jobs=-1, verbose=1)

model.fit(X_train, Y_train)

# Save the model
dump(model, os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_best_model_{}_{}.joblib'.format(HS_DATE, DATA_ID))) 

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    1.9s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:    5.0s finished


['../../../results/0003_11042021/RF_12042021/RF_best_model_12042021_0003_11042021.joblib']

In [11]:
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

results = {
    'Train': {
        'MAE': mean_absolute_error(Y_train, train_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_train, train_preds, multioutput='raw_values'),
        'R2': r2_score(Y_train, train_preds, multioutput='raw_values')
    },
    'Test': {
        'MAE': mean_absolute_error(Y_test, test_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_test, test_preds, multioutput='raw_values'),
        'R2': r2_score(Y_test, test_preds, multioutput='raw_values')
    }       
    
}

# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy', 'Fz']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.2s
[Parallel(n_jobs=8)]: Done 100 out of 100 | elapsed:    0.5s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Done 100 out of 100 | elapsed:    0.1s finished


Train Fx MAE: 8.6461 ± 2.4048
Train Fx MSE: 151.6214 ± 77.8937
Train Fx R2: 0.5655 ± 0.0612
Train Fy MAE: 8.7525 ± 6.0748
Train Fy MSE: 238.7282 ± 265.7849
Train Fy R2: 0.4389 ± 0.1117
Train Fz MAE: 11.2891 ± 3.9985
Train Fz MSE: 300.2048 ± 246.7115
Train Fz R2: 0.5458 ± 0.0617
Test Fx MAE: 13.0032 ± 5.0952
Test Fx MSE: 311.5999 ± 249.7830
Test Fx R2: 0.3473 ± 0.2863
Test Fy MAE: 10.7406 ± 7.7562
Test Fy MSE: 378.7206 ± 481.6318
Test Fy R2: 0.3404 ± 0.1027
Test Fz MAE: 19.0024 ± 6.5255
Test Fz MSE: 687.2684 ± 593.5296
Test Fz R2: 0.3611 ± 0.2479


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()