In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor

## Configuration

In [2]:
# How many force cells the robotic leg carries
N_CELLS = 8

# Root folder holding every results subfolder
RESULTS_PATH = '../../../../results'
# Dataset identifier; it names the subfolder of RESULTS_PATH that is read from and written to
DATA_ID = '0006_14062021'
# Date tag of the hyperparameter search; it selects the RF_<HS_DATE> folder and file names
HS_DATE = '14062021'
# Fold count used for cross-validation
CV = 6

print('Model trained with data: {}'.format(DATA_ID))

# Show every column when a DataFrame is displayed
pd.set_option('display.max_columns', None)

Model trained with data: 0006_14062021


## Hyperparameter search analysis

In [7]:
# Collect the per-configuration result files written by the hyperparameter search.
# glob.glob returns matches in arbitrary, filesystem-dependent order, so the list is
# sorted to keep the row order of the results table reproducible across runs and machines.
results_pattern = os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_{}_*.json'.format(HS_DATE))
results_files_ls = sorted(glob.glob(results_pattern))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 7


In [8]:
# Flatten every results file into one record, then build the DataFrame in a single step
results_ls = []
for path in results_files_ls:
    with open(path) as fh:
        payload = json.load(fh)

    # Identifier first, then the hyperparameters under a 'param_' prefix
    record = {'params_ID': payload['id']}
    record.update({'param_' + name: setting for name, setting in payload['parameters'].items()})

    # Each cross-validation entry is reduced to its mean and standard deviation
    for metric, fold_values in payload['cv_results'].items():
        record[metric + '__mean'] = np.mean(fold_values)
        record[metric + '__std'] = np.std(fold_values)

    results_ls.append(record)

results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Valid_Fz_MAE_mean__mean,Valid_Fz_MAE_mean__std,Valid_Fz_MAE_std__mean,Valid_Fz_MAE_std__std,Valid_Fz_MSE_mean__mean,Valid_Fz_MSE_mean__std,Valid_Fz_MSE_std__mean,Valid_Fz_MSE_std__std,Valid_Fz_R2_mean__mean,Valid_Fz_R2_mean__std,Valid_Fz_R2_std__mean,Valid_Fz_R2_std__std
0,CHTMFXYETA,15,0.2,0.001,0.0001,1000,261.464769,1.351822,8.555364,0.090836,3.067386,0.096559,145.878693,3.326182,109.517486,4.924489,0.525717,0.004452,0.105119,0.00441,6.918961,0.041594,3.95931,0.097422,130.325263,2.608796,140.427604,6.231722,0.457929,0.007137,0.073389,0.004964,14.06467,0.179834,5.232132,0.140406,384.106785,9.696091,298.369733,9.975728,0.45358,0.007517,0.146701,0.007238,9.669755,0.41495,3.671564,0.55647,186.800359,23.699804,150.536908,39.378843,0.383001,0.033911,0.146508,0.028907,7.731048,0.238702,4.601346,0.500024,172.930003,30.936871,210.179396,71.084861,0.312621,0.045341,0.104479,0.028481,15.886053,0.839656,6.051676,0.651026,491.222121,60.761519,396.140546,77.534941,0.264666,0.057232,0.231468,0.063318
1,FOJGVO06K4,5,0.2,0.001,0.001,5000,735.930935,33.738035,9.959396,0.123579,3.547758,0.122011,199.526586,5.873323,147.515305,7.778672,0.352752,0.006172,0.111952,0.0021,8.22761,0.079129,4.8694,0.132759,199.554283,7.420781,226.629534,15.865715,0.241237,0.006941,0.068138,0.003954,16.477083,0.211616,6.292219,0.167958,530.543698,12.755342,452.260356,17.313801,0.266962,0.005017,0.15225,0.005313,10.314933,0.490094,3.846167,0.583364,214.749505,29.147854,167.935911,48.309832,0.292414,0.029296,0.132283,0.020721,8.443104,0.365394,5.108554,0.542932,212.322671,40.339561,264.107949,82.164286,0.191325,0.025097,0.085088,0.015516,17.029729,0.952882,6.647709,0.800005,569.874965,67.6099,498.245716,105.786553,0.181323,0.044848,0.200258,0.049567
2,M61UEBMW2W,10,0.1,0.1,0.01,100,7.249432,0.092028,10.85437,0.130196,3.63308,0.139923,241.119979,5.790428,163.535988,9.109689,0.208573,0.005018,0.079591,0.002483,8.935382,0.102916,5.370592,0.135708,241.086046,8.228274,271.5329,15.316895,0.104102,0.003308,0.049151,0.002375,17.785694,0.212771,6.92373,0.161816,639.67114,14.081926,591.121489,23.741472,0.146416,0.004924,0.107623,0.004713,11.018534,0.583978,3.868267,0.575034,248.108064,29.849554,179.007857,49.700323,0.170449,0.021629,0.098268,0.020299,9.027627,0.49312,5.528498,0.559732,245.405661,42.221195,301.14344,71.804138,0.075586,0.017279,0.056788,0.013749,18.031296,0.993539,7.137085,0.789513,654.69272,71.028524,607.099473,126.027864,0.088898,0.039515,0.15518,0.047381
3,TQG9LEJ24Q,15,0.05,0.01,0.0001,5000,900.22464,14.420996,9.777916,0.106226,3.486004,0.110139,193.231488,4.832284,141.876977,6.919773,0.373757,0.006025,0.10768,0.002527,8.009687,0.060754,4.74784,0.118329,190.35631,5.707914,215.368716,12.753836,0.278955,0.007245,0.052385,0.004605,16.231817,0.20277,6.24199,0.171828,526.520955,12.854846,462.820445,17.812369,0.280827,0.007221,0.150585,0.005067,10.205625,0.478863,3.822026,0.576831,210.8217,27.813037,165.049691,46.568051,0.305601,0.02681,0.132571,0.022245,8.286607,0.324493,5.021322,0.510792,206.692158,38.239849,257.580558,78.495509,0.217202,0.030452,0.081066,0.01922,16.902587,0.934523,6.611143,0.790886,569.706727,68.044943,504.532422,103.014624,0.183179,0.047998,0.208161,0.050532
4,UM5WLSSGTB,13,0.1,0.1,0.0001,5000,364.650932,3.787503,10.907341,0.14121,3.623359,0.136102,244.297986,5.903655,164.255426,8.895655,0.197869,0.004395,0.072849,0.002228,8.907411,0.098929,5.33928,0.131875,240.590771,8.149099,270.799428,15.225951,0.105068,0.003346,0.050733,0.002283,17.820575,0.220355,6.953428,0.150664,644.588642,14.461177,601.537502,23.413891,0.142101,0.004565,0.104456,0.003878,11.07037,0.585445,3.863062,0.565655,251.247999,29.744516,179.935656,49.24995,0.159325,0.021726,0.091857,0.021262,9.000715,0.488784,5.497791,0.558584,245.010434,42.351841,300.740299,72.025909,0.076297,0.017387,0.057961,0.013442,18.062377,0.992469,7.164595,0.780508,659.361554,71.177406,617.293776,127.312276,0.08513,0.039283,0.151648,0.046192
5,W7C4X8XELO,10,0.01,0.01,0.0001,10000,1790.082741,11.482721,9.779844,0.106425,3.487188,0.110402,193.28868,4.83212,141.942986,6.932739,0.37359,0.005911,0.10775,0.002482,8.012475,0.060213,4.749639,0.117979,190.372586,5.672989,215.383074,12.7279,0.278752,0.007177,0.052576,0.004528,16.235743,0.202808,6.243373,0.1712,526.659871,12.812605,462.82691,17.702541,0.280599,0.007134,0.150662,0.005034,10.206218,0.478787,3.822921,0.57694,210.827712,27.845892,165.052513,46.577055,0.305592,0.026744,0.1326,0.022187,8.287461,0.323958,5.021355,0.510701,206.627114,38.197952,257.449847,78.472029,0.217207,0.030634,0.081142,0.019223,16.902625,0.936547,6.611363,0.792644,569.675008,68.284624,504.490983,103.346297,0.18323,0.048172,0.208228,0.050531
6,WC89T5JTL8,10,0.2,0.1,0.01,5000,452.330458,14.601387,10.907341,0.14121,3.623359,0.136102,244.297986,5.903655,164.255426,8.895655,0.197869,0.004395,0.072849,0.002228,8.907411,0.098929,5.33928,0.131875,240.590771,8.149099,270.799428,15.225951,0.105068,0.003346,0.050733,0.002283,17.820575,0.220355,6.953428,0.150664,644.588642,14.461177,601.537502,23.413891,0.142101,0.004565,0.104456,0.003878,11.07037,0.585445,3.863062,0.565655,251.247999,29.744516,179.935656,49.24995,0.159325,0.021726,0.091857,0.021262,9.000715,0.488784,5.497791,0.558584,245.010434,42.351841,300.740299,72.025909,0.076297,0.017387,0.057961,0.013442,18.062377,0.992469,7.164595,0.780508,659.361554,71.177406,617.293776,127.312276,0.08513,0.039283,0.151648,0.046192


In [9]:
# Average the Fx/Fy/Fz scores into one sortable column per subset and metric
for subset in ('Train', 'Valid'):
    for loss in ('MAE', 'MSE', 'R2'):
        axis_columns = [
            '{}_{}_{}_mean__mean'.format(subset, force, loss)
            for force in ('Fx', 'Fy', 'Fz')
        ]
        results_df['{}_{}'.format(subset, loss)] = results_df[axis_columns].mean(axis=1)

In [10]:
# Rank the configurations from highest to lowest mean validation R2
results_df = results_df.sort_values(by='Valid_R2', ascending=False)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Valid_Fz_MAE_mean__mean,Valid_Fz_MAE_mean__std,Valid_Fz_MAE_std__mean,Valid_Fz_MAE_std__std,Valid_Fz_MSE_mean__mean,Valid_Fz_MSE_mean__std,Valid_Fz_MSE_std__mean,Valid_Fz_MSE_std__std,Valid_Fz_R2_mean__mean,Valid_Fz_R2_mean__std,Valid_Fz_R2_std__mean,Valid_Fz_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
0,CHTMFXYETA,15,0.2,0.001,0.0001,1000,261.464769,1.351822,8.555364,0.090836,3.067386,0.096559,145.878693,3.326182,109.517486,4.924489,0.525717,0.004452,0.105119,0.00441,6.918961,0.041594,3.95931,0.097422,130.325263,2.608796,140.427604,6.231722,0.457929,0.007137,0.073389,0.004964,14.06467,0.179834,5.232132,0.140406,384.106785,9.696091,298.369733,9.975728,0.45358,0.007517,0.146701,0.007238,9.669755,0.41495,3.671564,0.55647,186.800359,23.699804,150.536908,39.378843,0.383001,0.033911,0.146508,0.028907,7.731048,0.238702,4.601346,0.500024,172.930003,30.936871,210.179396,71.084861,0.312621,0.045341,0.104479,0.028481,15.886053,0.839656,6.051676,0.651026,491.222121,60.761519,396.140546,77.534941,0.264666,0.057232,0.231468,0.063318,9.846332,220.10358,0.479075,11.095619,283.650828,0.320096
5,W7C4X8XELO,10,0.01,0.01,0.0001,10000,1790.082741,11.482721,9.779844,0.106425,3.487188,0.110402,193.28868,4.83212,141.942986,6.932739,0.37359,0.005911,0.10775,0.002482,8.012475,0.060213,4.749639,0.117979,190.372586,5.672989,215.383074,12.7279,0.278752,0.007177,0.052576,0.004528,16.235743,0.202808,6.243373,0.1712,526.659871,12.812605,462.82691,17.702541,0.280599,0.007134,0.150662,0.005034,10.206218,0.478787,3.822921,0.57694,210.827712,27.845892,165.052513,46.577055,0.305592,0.026744,0.1326,0.022187,8.287461,0.323958,5.021355,0.510701,206.627114,38.197952,257.449847,78.472029,0.217207,0.030634,0.081142,0.019223,16.902625,0.936547,6.611363,0.792644,569.675008,68.284624,504.490983,103.346297,0.18323,0.048172,0.208228,0.050531,11.342687,303.440379,0.31098,11.798768,329.043278,0.235343
3,TQG9LEJ24Q,15,0.05,0.01,0.0001,5000,900.22464,14.420996,9.777916,0.106226,3.486004,0.110139,193.231488,4.832284,141.876977,6.919773,0.373757,0.006025,0.10768,0.002527,8.009687,0.060754,4.74784,0.118329,190.35631,5.707914,215.368716,12.753836,0.278955,0.007245,0.052385,0.004605,16.231817,0.20277,6.24199,0.171828,526.520955,12.854846,462.820445,17.812369,0.280827,0.007221,0.150585,0.005067,10.205625,0.478863,3.822026,0.576831,210.8217,27.813037,165.049691,46.568051,0.305601,0.02681,0.132571,0.022245,8.286607,0.324493,5.021322,0.510792,206.692158,38.239849,257.580558,78.495509,0.217202,0.030452,0.081066,0.01922,16.902587,0.934523,6.611143,0.790886,569.706727,68.044943,504.532422,103.014624,0.183179,0.047998,0.208161,0.050532,11.339807,303.369584,0.31118,11.798273,329.073529,0.235328
1,FOJGVO06K4,5,0.2,0.001,0.001,5000,735.930935,33.738035,9.959396,0.123579,3.547758,0.122011,199.526586,5.873323,147.515305,7.778672,0.352752,0.006172,0.111952,0.0021,8.22761,0.079129,4.8694,0.132759,199.554283,7.420781,226.629534,15.865715,0.241237,0.006941,0.068138,0.003954,16.477083,0.211616,6.292219,0.167958,530.543698,12.755342,452.260356,17.313801,0.266962,0.005017,0.15225,0.005313,10.314933,0.490094,3.846167,0.583364,214.749505,29.147854,167.935911,48.309832,0.292414,0.029296,0.132283,0.020721,8.443104,0.365394,5.108554,0.542932,212.322671,40.339561,264.107949,82.164286,0.191325,0.025097,0.085088,0.015516,17.029729,0.952882,6.647709,0.800005,569.874965,67.6099,498.245716,105.786553,0.181323,0.044848,0.200258,0.049567,11.554696,309.874856,0.286984,11.929255,332.315714,0.221687
2,M61UEBMW2W,10,0.1,0.1,0.01,100,7.249432,0.092028,10.85437,0.130196,3.63308,0.139923,241.119979,5.790428,163.535988,9.109689,0.208573,0.005018,0.079591,0.002483,8.935382,0.102916,5.370592,0.135708,241.086046,8.228274,271.5329,15.316895,0.104102,0.003308,0.049151,0.002375,17.785694,0.212771,6.92373,0.161816,639.67114,14.081926,591.121489,23.741472,0.146416,0.004924,0.107623,0.004713,11.018534,0.583978,3.868267,0.575034,248.108064,29.849554,179.007857,49.700323,0.170449,0.021629,0.098268,0.020299,9.027627,0.49312,5.528498,0.559732,245.405661,42.221195,301.14344,71.804138,0.075586,0.017279,0.056788,0.013749,18.031296,0.993539,7.137085,0.789513,654.69272,71.028524,607.099473,126.027864,0.088898,0.039515,0.15518,0.047381,12.525149,373.959055,0.15303,12.692486,382.735482,0.111644
4,UM5WLSSGTB,13,0.1,0.1,0.0001,5000,364.650932,3.787503,10.907341,0.14121,3.623359,0.136102,244.297986,5.903655,164.255426,8.895655,0.197869,0.004395,0.072849,0.002228,8.907411,0.098929,5.33928,0.131875,240.590771,8.149099,270.799428,15.225951,0.105068,0.003346,0.050733,0.002283,17.820575,0.220355,6.953428,0.150664,644.588642,14.461177,601.537502,23.413891,0.142101,0.004565,0.104456,0.003878,11.07037,0.585445,3.863062,0.565655,251.247999,29.744516,179.935656,49.24995,0.159325,0.021726,0.091857,0.021262,9.000715,0.488784,5.497791,0.558584,245.010434,42.351841,300.740299,72.025909,0.076297,0.017387,0.057961,0.013442,18.062377,0.992469,7.164595,0.780508,659.361554,71.177406,617.293776,127.312276,0.08513,0.039283,0.151648,0.046192,12.545109,376.492466,0.148346,12.711154,385.206662,0.106917
6,WC89T5JTL8,10,0.2,0.1,0.01,5000,452.330458,14.601387,10.907341,0.14121,3.623359,0.136102,244.297986,5.903655,164.255426,8.895655,0.197869,0.004395,0.072849,0.002228,8.907411,0.098929,5.33928,0.131875,240.590771,8.149099,270.799428,15.225951,0.105068,0.003346,0.050733,0.002283,17.820575,0.220355,6.953428,0.150664,644.588642,14.461177,601.537502,23.413891,0.142101,0.004565,0.104456,0.003878,11.07037,0.585445,3.863062,0.565655,251.247999,29.744516,179.935656,49.24995,0.159325,0.021726,0.091857,0.021262,9.000715,0.488784,5.497791,0.558584,245.010434,42.351841,300.740299,72.025909,0.076297,0.017387,0.057961,0.013442,18.062377,0.992469,7.164595,0.780508,659.361554,71.177406,617.293776,127.312276,0.08513,0.039283,0.151648,0.046192,12.545109,376.492466,0.148346,12.711154,385.206662,0.106917


In [11]:
# Hyperparameters of the top-ranked configuration (first row of the sorted table).
# Columns are selected by the exact 'param_' prefix, and only that prefix is stripped:
# a substring test plus str.replace would also pick up, or mangle, any column whose
# name merely contains 'param_' somewhere else.
PARAM_PREFIX = 'param_'
best_row = results_df.iloc[0]
best_params = {
    col[len(PARAM_PREFIX):]: best_row[col]
    for col in results_df.columns
    if col.startswith(PARAM_PREFIX)
}
print('Best parameters: {}'.format(best_params))

Best parameters: {'max_depth': 15, 'max_features': 0.2, 'min_samples_leaf': 0.001, 'min_samples_split': 0.0001, 'n_estimators': 1000}


## Best model

In [8]:
# Read the train/test arrays stored under <RESULTS_PATH>/<DATA_ID>/data
data_dir = os.path.join(RESULTS_PATH, DATA_ID, 'data')
X_train = np.load(os.path.join(data_dir, 'X_train_' + DATA_ID + '.npy'))
X_test = np.load(os.path.join(data_dir, 'X_test_' + DATA_ID + '.npy'))
Y_train = np.load(os.path.join(data_dir, 'Y_train_' + DATA_ID + '.npy'))
Y_test = np.load(os.path.join(data_dir, 'Y_test_' + DATA_ID + '.npy'))

In [10]:
# Build the forest from the selected hyperparameters; the seed is fixed and all cores are used
model = RandomForestRegressor(random_state=0, n_jobs=-1, verbose=1, **best_params)

model.fit(X_train, Y_train)

# Persist the fitted model next to the hyperparameter search results
model_dir = os.path.join(RESULTS_PATH, DATA_ID, 'RF_' + HS_DATE)
model_file = 'RF_best_model_{}_{}.joblib'.format(HS_DATE, DATA_ID)
dump(model, os.path.join(model_dir, model_file))

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    1.9s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:    5.0s finished


['../../../results/0003_11042021/RF_12042021/RF_best_model_12042021_0003_11042021.joblib']

In [11]:
# Predict both subsets with the fitted model
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

# One score per target column ('raw_values') for every subset and metric
results = {}
for subset, y_true, y_pred in [('Train', Y_train, train_preds), ('Test', Y_test, test_preds)]:
    results[subset] = {
        'MAE': mean_absolute_error(y_true, y_pred, multioutput='raw_values'),
        'MSE': mean_squared_error(y_true, y_pred, multioutput='raw_values'),
        'R2': r2_score(y_true, y_pred, multioutput='raw_values'),
    }

# Print mean ± std of each metric across the force cells, one line per subset/axis/metric.
# Column 3 * cell + axis_offset is read as one axis of one cell, i.e. the targets are
# assumed to be stored as consecutive (Fx, Fy, Fz) triplets, one per cell -- TODO confirm.
# The individual per-cell values remain available in results[subset][loss].
for subset in ['Train', 'Test']:
    for axis_offset, force in enumerate(['Fx', 'Fy', 'Fz']):
        for loss in ['MAE', 'MSE', 'R2']:
            per_cell = [results[subset][loss][3 * cell + axis_offset] for cell in range(N_CELLS)]
            print('{} {} {}: {:.4f} ± {:.4f}'.format(subset, force, loss, np.mean(per_cell), np.std(per_cell)))

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.2s
[Parallel(n_jobs=8)]: Done 100 out of 100 | elapsed:    0.5s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Done 100 out of 100 | elapsed:    0.1s finished


Train Fx MAE: 8.6461 ± 2.4048
Train Fx MSE: 151.6214 ± 77.8937
Train Fx R2: 0.5655 ± 0.0612
Train Fy MAE: 8.7525 ± 6.0748
Train Fy MSE: 238.7282 ± 265.7849
Train Fy R2: 0.4389 ± 0.1117
Train Fz MAE: 11.2891 ± 3.9985
Train Fz MSE: 300.2048 ± 246.7115
Train Fz R2: 0.5458 ± 0.0617
Test Fx MAE: 13.0032 ± 5.0952
Test Fx MSE: 311.5999 ± 249.7830
Test Fx R2: 0.3473 ± 0.2863
Test Fy MAE: 10.7406 ± 7.7562
Test Fy MSE: 378.7206 ± 481.6318
Test Fy R2: 0.3404 ± 0.1027
Test Fz MAE: 19.0024 ± 6.5255
Test Fz MSE: 687.2684 ± 593.5296
Test Fz R2: 0.3611 ± 0.2479


In [None]:
# Disabled diagnostic plots: each figure overlays true values and predictions as a
# scatter of target column 3 against target column 4 -- all training samples, the
# first 100 training samples, and all test samples.
# NOTE(review): columns 3 and 4 are presumably Fx and Fy of the second force cell -- confirm.

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()