In [1]:
import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from joblib import dump
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

## Configuration

In [2]:
# --- Display options ---------------------------------------------------------
# Show every column when a DataFrame is rendered (the results table is very wide).
pd.set_option('display.max_columns', None)

# --- Experiment constants ----------------------------------------------------
# Number of force cells in the robotic leg
N_CELLS = 8
# Number of folds in cross-validation
CV = 6

# Root folder where all results are stored
RESULTS_PATH = '../../../results'
# ID of the dataset folder under RESULTS_PATH; the train/test arrays and the
# hyperparameter search results are read from inside it
DATA_ID = '0003_11042021'
# Date tag of the hyperparameter search (part of the 'RF_<date>' folder and file names)
HS_DATE = '12042021'

print('Model trained with data: ' + DATA_ID)

Model trained with data: 0003_11042021


## Hyperparameter search analysis

In [3]:
# Collect the JSON result files written by the hyperparameter search.
# glob returns paths in an arbitrary, filesystem-dependent order, so the list is
# sorted to keep the row order (and index) of the results table reproducible
# across runs and machines.
search_dir = os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE))
results_files_ls = sorted(glob.glob(os.path.join(search_dir, 'RF_{}_*.json'.format(HS_DATE))))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of result files: 7


In [4]:
def _flatten_result(result):
    """Turn one search-result dictionary into a flat row for the results table.

    The row holds the run id under 'params_ID', every entry of
    result['parameters'] under a 'param_'-prefixed key, and, for every entry of
    result['cv_results'], its mean and standard deviation under
    '<key>__mean' and '<key>__std'.
    """
    row = {'params_ID': result['id']}
    for name, setting in result['parameters'].items():
        row['param_' + name] = setting
    for metric, values in result['cv_results'].items():
        # presumably one value per cross-validation fold -- confirm against the search script
        row['__'.join([metric, 'mean'])] = np.mean(values)
        row['__'.join([metric, 'std'])] = np.std(values)
    return row


# Load every result file and build one DataFrame row per hyperparameter set
results_ls = []
for path in results_files_ls:
    with open(path) as handle:
        results_ls.append(_flatten_result(json.load(handle)))

results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Test_Fx_MAE_mean__mean,Test_Fx_MAE_mean__std,Test_Fx_MAE_std__mean,Test_Fx_MAE_std__std,Test_Fx_MSE_mean__mean,Test_Fx_MSE_mean__std,Test_Fx_MSE_std__mean,Test_Fx_MSE_std__std,Test_Fx_R2_mean__mean,Test_Fx_R2_mean__std,Test_Fx_R2_std__mean,Test_Fx_R2_std__std,Test_Fy_MAE_mean__mean,Test_Fy_MAE_mean__std,Test_Fy_MAE_std__mean,Test_Fy_MAE_std__std,Test_Fy_MSE_mean__mean,Test_Fy_MSE_mean__std,Test_Fy_MSE_std__mean,Test_Fy_MSE_std__std,Test_Fy_R2_mean__mean,Test_Fy_R2_mean__std,Test_Fy_R2_std__mean,Test_Fy_R2_std__std,Test_Fz_MAE_mean__mean,Test_Fz_MAE_mean__std,Test_Fz_MAE_std__mean,Test_Fz_MAE_std__std,Test_Fz_MSE_mean__mean,Test_Fz_MSE_mean__std,Test_Fz_MSE_std__mean,Test_Fz_MSE_std__std,Test_Fz_R2_mean__mean,Test_Fz_R2_mean__std,Test_Fz_R2_std__mean,Test_Fz_R2_std__std
0,0J98DBZLWB,5,0.3,0.01,0.1,100,5.794,0.188623,10.524595,0.254087,3.234118,0.115952,215.6064,11.578915,124.278352,4.297752,0.374334,0.033817,0.062223,0.004895,10.125942,0.476087,7.11806,0.418942,323.415712,35.123979,379.592098,49.927343,0.266373,0.01895,0.088513,0.014662,13.724185,0.437286,4.851363,0.206986,435.51217,27.089149,372.665154,31.054091,0.330505,0.027823,0.087846,0.012346,13.608913,1.371315,5.399668,0.525332,372.471861,63.627794,261.324848,28.568873,-0.003642,0.238723,0.240963,0.128826,12.008707,3.069492,9.102462,2.903734,511.465863,284.772698,771.414855,491.432828,-0.042728,0.120944,0.211782,0.088686,17.507047,2.235787,7.344584,1.066767,658.664467,155.348706,604.734946,183.338186,-0.127095,0.175042,0.350159,0.115422
1,2IW273I4T7,5,0.3,0.01,0.2,10000,393.117657,12.298619,11.244871,0.271398,3.544623,0.172447,247.260898,12.958611,146.46824,5.903505,0.288164,0.035454,0.061329,0.007681,10.573115,0.526584,7.397988,0.456945,358.352982,41.498873,428.44489,60.36262,0.192326,0.020062,0.086655,0.018301,14.768185,0.407943,5.371083,0.237445,510.250548,24.576719,468.048184,33.203192,0.238627,0.026171,0.077279,0.009882,13.80174,1.166627,5.678779,0.550038,378.935846,53.620489,275.849474,26.540841,0.001591,0.182689,0.227493,0.09725,12.1129,2.859342,9.164995,2.793323,510.721163,272.645053,775.106832,497.536529,-0.060864,0.085815,0.210944,0.092336,17.468819,1.961279,7.110134,0.826265,638.486086,134.403324,571.989374,172.620562,-0.132244,0.122642,0.407105,0.135357
2,3FV2UQB8L7,5,0.3,0.01,0.2,1000,83.373339,3.257934,11.24579,0.269816,3.549518,0.171756,247.322266,12.930461,146.757598,5.905896,0.288215,0.035214,0.061395,0.007681,10.567766,0.523876,7.395546,0.455501,357.935681,41.280046,427.897296,60.144316,0.193281,0.020234,0.087151,0.018711,14.775625,0.409892,5.372031,0.233998,510.459551,24.347265,468.060718,33.161035,0.238151,0.02621,0.077765,0.009591,13.808141,1.160099,5.694871,0.544838,379.320908,53.596304,276.570847,25.990296,0.001044,0.182792,0.228814,0.098447,12.113012,2.871817,9.176388,2.814672,511.114207,273.986086,776.755217,501.103462,-0.059927,0.087595,0.210795,0.09216,17.479462,1.959982,7.10452,0.808878,639.015347,135.309731,571.832532,173.027601,-0.133758,0.123198,0.409411,0.136214
3,C4B6JI4IVS,5,0.3,0.01,0.1,500,27.732685,1.953594,10.446101,0.25466,3.219161,0.110539,212.752572,11.330539,122.568164,4.576076,0.382867,0.031964,0.063623,0.003748,10.057437,0.481708,7.045409,0.417817,317.552607,34.884449,370.777567,48.53806,0.272793,0.018499,0.091679,0.014517,13.629574,0.406168,4.848831,0.20736,430.679856,24.530712,368.935807,27.818196,0.339036,0.024746,0.087035,0.011018,13.551268,1.373641,5.405589,0.549479,371.301519,65.025369,262.155196,31.721316,0.006628,0.221516,0.230702,0.102706,11.889677,3.02086,8.963393,2.793231,501.066007,274.737435,753.618556,464.244618,-0.030869,0.119464,0.197648,0.074375,17.413047,2.201543,7.354422,1.062491,654.461704,155.137638,602.817214,186.50559,-0.113945,0.166752,0.342951,0.120862
4,IULDYWVS3D,5,0.3,0.01,0.2,500,41.467785,1.108158,11.237776,0.266124,3.544505,0.176226,246.969746,12.761921,146.4921,6.030585,0.289056,0.034802,0.061278,0.007215,10.569171,0.518158,7.399818,0.45173,358.193331,40.933486,428.571002,59.629306,0.193485,0.019545,0.086632,0.018485,14.777122,0.402019,5.368534,0.233122,510.130815,23.809942,467.015455,33.472983,0.238111,0.025491,0.077792,0.009214,13.804031,1.154316,5.69109,0.536724,379.051345,53.616194,276.346762,25.709488,0.001642,0.182047,0.229164,0.098084,12.115957,2.874342,9.182035,2.816639,511.027404,273.864962,776.457197,500.805032,-0.060057,0.088375,0.211236,0.0936,17.482915,1.959826,7.103005,0.805775,639.038626,135.326954,571.657402,172.848072,-0.134351,0.123237,0.410682,0.137242
5,UBRFZDSOSJ,5,0.3,0.01,0.2,5000,320.885,114.149445,11.243648,0.271575,3.543637,0.171155,247.220749,12.954255,146.432939,5.86967,0.288286,0.035415,0.061289,0.007599,10.572245,0.525498,7.397747,0.455431,358.303285,41.382958,428.384702,60.23068,0.192517,0.020135,0.0867,0.018472,14.766767,0.407935,5.371068,0.236288,510.256821,24.468149,468.152879,33.254891,0.238695,0.026119,0.077317,0.009808,13.804771,1.166555,5.682405,0.551952,379.1214,53.653037,276.047301,26.572662,0.001242,0.182508,0.22756,0.097059,12.114901,2.861287,9.168947,2.796033,511.036299,273.057577,775.845497,498.472195,-0.060938,0.086166,0.210989,0.092627,17.473644,1.964906,7.111498,0.825385,639.058623,134.910508,572.632325,173.18356,-0.133026,0.122553,0.407491,0.135084
6,ZNX34F2QWQ,5,0.3,0.01,0.2,100,8.015768,0.330225,11.249076,0.23771,3.530468,0.155693,247.230701,11.731414,146.107766,4.799538,0.287538,0.033465,0.060849,0.007621,10.570429,0.501241,7.39124,0.437075,357.830158,39.855769,427.659839,57.65358,0.193079,0.01736,0.086329,0.018157,14.77268,0.384427,5.346139,0.225043,509.97001,22.228232,466.551909,33.180924,0.23786,0.024157,0.07758,0.00954,13.786431,1.162083,5.654476,0.567133,377.87621,52.87921,274.27008,26.189279,0.002668,0.185789,0.227153,0.101817,12.128288,2.893616,9.193623,2.833978,512.6076,275.351835,779.82636,502.361208,-0.060138,0.090962,0.211055,0.094525,17.477146,1.96223,7.102168,0.798254,639.141268,134.134625,572.873847,170.701886,-0.132574,0.126708,0.405841,0.135922


In [5]:
# Collapse the per-axis scores into a single sortable column per subset and metric:
# each new '<subset>_<loss>' column is the row-wise mean of the Fx, Fy and Fz
# '<subset>_<axis>_<loss>_mean__mean' columns.
for subset in ['Train', 'Test']:
    for loss in ['MAE', 'MSE', 'R2']:
        axis_columns = ['_'.join([subset, force, loss]) + '_mean__mean' for force in ['Fx', 'Fy', 'Fz']]
        results_df['_'.join([subset, loss])] = results_df[axis_columns].mean(axis=1)

In [6]:
# Rank the hyperparameter sets from best to worst on the most relevant score,
# the axis-averaged test R2 (higher is better, hence descending order)
results_df = results_df.sort_values(by='Test_R2', ascending=False)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Train_Fz_MAE_mean__mean,Train_Fz_MAE_mean__std,Train_Fz_MAE_std__mean,Train_Fz_MAE_std__std,Train_Fz_MSE_mean__mean,Train_Fz_MSE_mean__std,Train_Fz_MSE_std__mean,Train_Fz_MSE_std__std,Train_Fz_R2_mean__mean,Train_Fz_R2_mean__std,Train_Fz_R2_std__mean,Train_Fz_R2_std__std,Test_Fx_MAE_mean__mean,Test_Fx_MAE_mean__std,Test_Fx_MAE_std__mean,Test_Fx_MAE_std__std,Test_Fx_MSE_mean__mean,Test_Fx_MSE_mean__std,Test_Fx_MSE_std__mean,Test_Fx_MSE_std__std,Test_Fx_R2_mean__mean,Test_Fx_R2_mean__std,Test_Fx_R2_std__mean,Test_Fx_R2_std__std,Test_Fy_MAE_mean__mean,Test_Fy_MAE_mean__std,Test_Fy_MAE_std__mean,Test_Fy_MAE_std__std,Test_Fy_MSE_mean__mean,Test_Fy_MSE_mean__std,Test_Fy_MSE_std__mean,Test_Fy_MSE_std__std,Test_Fy_R2_mean__mean,Test_Fy_R2_mean__std,Test_Fy_R2_std__mean,Test_Fy_R2_std__std,Test_Fz_MAE_mean__mean,Test_Fz_MAE_mean__std,Test_Fz_MAE_std__mean,Test_Fz_MAE_std__std,Test_Fz_MSE_mean__mean,Test_Fz_MSE_mean__std,Test_Fz_MSE_std__mean,Test_Fz_MSE_std__std,Test_Fz_R2_mean__mean,Test_Fz_R2_mean__std,Test_Fz_R2_std__mean,Test_Fz_R2_std__std,Train_MAE,Train_MSE,Train_R2,Test_MAE,Test_MSE,Test_R2
3,C4B6JI4IVS,5,0.3,0.01,0.1,500,27.732685,1.953594,10.446101,0.25466,3.219161,0.110539,212.752572,11.330539,122.568164,4.576076,0.382867,0.031964,0.063623,0.003748,10.057437,0.481708,7.045409,0.417817,317.552607,34.884449,370.777567,48.53806,0.272793,0.018499,0.091679,0.014517,13.629574,0.406168,4.848831,0.20736,430.679856,24.530712,368.935807,27.818196,0.339036,0.024746,0.087035,0.011018,13.551268,1.373641,5.405589,0.549479,371.301519,65.025369,262.155196,31.721316,0.006628,0.221516,0.230702,0.102706,11.889677,3.02086,8.963393,2.793231,501.066007,274.737435,753.618556,464.244618,-0.030869,0.119464,0.197648,0.074375,17.413047,2.201543,7.354422,1.062491,654.461704,155.137638,602.817214,186.50559,-0.113945,0.166752,0.342951,0.120862,11.377704,320.328345,0.331565,14.284664,508.943077,-0.046062
0,0J98DBZLWB,5,0.3,0.01,0.1,100,5.794,0.188623,10.524595,0.254087,3.234118,0.115952,215.6064,11.578915,124.278352,4.297752,0.374334,0.033817,0.062223,0.004895,10.125942,0.476087,7.11806,0.418942,323.415712,35.123979,379.592098,49.927343,0.266373,0.01895,0.088513,0.014662,13.724185,0.437286,4.851363,0.206986,435.51217,27.089149,372.665154,31.054091,0.330505,0.027823,0.087846,0.012346,13.608913,1.371315,5.399668,0.525332,372.471861,63.627794,261.324848,28.568873,-0.003642,0.238723,0.240963,0.128826,12.008707,3.069492,9.102462,2.903734,511.465863,284.772698,771.414855,491.432828,-0.042728,0.120944,0.211782,0.088686,17.507047,2.235787,7.344584,1.066767,658.664467,155.348706,604.734946,183.338186,-0.127095,0.175042,0.350159,0.115422,11.458241,324.844761,0.323737,14.374889,514.20073,-0.057822
6,ZNX34F2QWQ,5,0.3,0.01,0.2,100,8.015768,0.330225,11.249076,0.23771,3.530468,0.155693,247.230701,11.731414,146.107766,4.799538,0.287538,0.033465,0.060849,0.007621,10.570429,0.501241,7.39124,0.437075,357.830158,39.855769,427.659839,57.65358,0.193079,0.01736,0.086329,0.018157,14.77268,0.384427,5.346139,0.225043,509.97001,22.228232,466.551909,33.180924,0.23786,0.024157,0.07758,0.00954,13.786431,1.162083,5.654476,0.567133,377.87621,52.87921,274.27008,26.189279,0.002668,0.185789,0.227153,0.101817,12.128288,2.893616,9.193623,2.833978,512.6076,275.351835,779.82636,502.361208,-0.060138,0.090962,0.211055,0.094525,17.477146,1.96223,7.102168,0.798254,639.141268,134.134625,572.873847,170.701886,-0.132574,0.126708,0.405841,0.135922,12.197395,371.676956,0.239492,14.463955,509.875026,-0.063348
1,2IW273I4T7,5,0.3,0.01,0.2,10000,393.117657,12.298619,11.244871,0.271398,3.544623,0.172447,247.260898,12.958611,146.46824,5.903505,0.288164,0.035454,0.061329,0.007681,10.573115,0.526584,7.397988,0.456945,358.352982,41.498873,428.44489,60.36262,0.192326,0.020062,0.086655,0.018301,14.768185,0.407943,5.371083,0.237445,510.250548,24.576719,468.048184,33.203192,0.238627,0.026171,0.077279,0.009882,13.80174,1.166627,5.678779,0.550038,378.935846,53.620489,275.849474,26.540841,0.001591,0.182689,0.227493,0.09725,12.1129,2.859342,9.164995,2.793323,510.721163,272.645053,775.106832,497.536529,-0.060864,0.085815,0.210944,0.092336,17.468819,1.961279,7.110134,0.826265,638.486086,134.403324,571.989374,172.620562,-0.132244,0.122642,0.407105,0.135357,12.195391,371.954809,0.239706,14.461153,509.381032,-0.063839
2,3FV2UQB8L7,5,0.3,0.01,0.2,1000,83.373339,3.257934,11.24579,0.269816,3.549518,0.171756,247.322266,12.930461,146.757598,5.905896,0.288215,0.035214,0.061395,0.007681,10.567766,0.523876,7.395546,0.455501,357.935681,41.280046,427.897296,60.144316,0.193281,0.020234,0.087151,0.018711,14.775625,0.409892,5.372031,0.233998,510.459551,24.347265,468.060718,33.161035,0.238151,0.02621,0.077765,0.009591,13.808141,1.160099,5.694871,0.544838,379.320908,53.596304,276.570847,25.990296,0.001044,0.182792,0.228814,0.098447,12.113012,2.871817,9.176388,2.814672,511.114207,273.986086,776.755217,501.103462,-0.059927,0.087595,0.210795,0.09216,17.479462,1.959982,7.10452,0.808878,639.015347,135.309731,571.832532,173.027601,-0.133758,0.123198,0.409411,0.136214,12.196394,371.905833,0.239882,14.466872,509.816821,-0.064214
5,UBRFZDSOSJ,5,0.3,0.01,0.2,5000,320.885,114.149445,11.243648,0.271575,3.543637,0.171155,247.220749,12.954255,146.432939,5.86967,0.288286,0.035415,0.061289,0.007599,10.572245,0.525498,7.397747,0.455431,358.303285,41.382958,428.384702,60.23068,0.192517,0.020135,0.0867,0.018472,14.766767,0.407935,5.371068,0.236288,510.256821,24.468149,468.152879,33.254891,0.238695,0.026119,0.077317,0.009808,13.804771,1.166555,5.682405,0.551952,379.1214,53.653037,276.047301,26.572662,0.001242,0.182508,0.22756,0.097059,12.114901,2.861287,9.168947,2.796033,511.036299,273.057577,775.845497,498.472195,-0.060938,0.086166,0.210989,0.092627,17.473644,1.964906,7.111498,0.825385,639.058623,134.910508,572.632325,173.18356,-0.133026,0.122553,0.407491,0.135084,12.19422,371.926952,0.239833,14.464439,509.738774,-0.064241
4,IULDYWVS3D,5,0.3,0.01,0.2,500,41.467785,1.108158,11.237776,0.266124,3.544505,0.176226,246.969746,12.761921,146.4921,6.030585,0.289056,0.034802,0.061278,0.007215,10.569171,0.518158,7.399818,0.45173,358.193331,40.933486,428.571002,59.629306,0.193485,0.019545,0.086632,0.018485,14.777122,0.402019,5.368534,0.233122,510.130815,23.809942,467.015455,33.472983,0.238111,0.025491,0.077792,0.009214,13.804031,1.154316,5.69109,0.536724,379.051345,53.616194,276.346762,25.709488,0.001642,0.182047,0.229164,0.098084,12.115957,2.874342,9.182035,2.816639,511.027404,273.864962,776.457197,500.805032,-0.060057,0.088375,0.211236,0.0936,17.482915,1.959826,7.103005,0.805775,639.038626,135.326954,571.657402,172.848072,-0.134351,0.123237,0.410682,0.137242,12.19469,371.764631,0.240217,14.467635,509.705792,-0.064255


In [7]:
# The table is sorted best-first, so the first row holds the winning hyperparameters.
# Keep only the 'param_*' columns and strip the prefix so the keys match the
# estimator's keyword arguments.
param_columns = [col for col in results_df.columns if 'param_' in col]
best_params = {
    name.replace('param_', ''): setting
    for name, setting in results_df.iloc[0][param_columns].items()
}
print('Best parameters: {}'.format(best_params))

Best parameters: {'max_depth': 5, 'max_features': 0.3, 'min_samples_leaf': 0.01, 'min_samples_split': 0.1, 'n_estimators': 500}


## Best model

In [None]:
# Load the train/test arrays saved under <RESULTS_PATH>/<DATA_ID>/data
data_dir = os.path.join(RESULTS_PATH, DATA_ID, 'data')

X_train = np.load(os.path.join(data_dir, 'X_train_{}.npy'.format(DATA_ID)))
Y_train = np.load(os.path.join(data_dir, 'Y_train_{}.npy'.format(DATA_ID)))
X_test = np.load(os.path.join(data_dir, 'X_test_{}.npy'.format(DATA_ID)))
Y_test = np.load(os.path.join(data_dir, 'Y_test_{}.npy'.format(DATA_ID)))

In [None]:
# Set up the model with the best hyperparameters found by the search.
# random_state is fixed so that refitting gives the same forest.
model = RandomForestRegressor(**best_params, random_state=0, n_jobs=-1, verbose=1)

model.fit(X_train, Y_train)

# Save the fitted model in the same folder as the search results.
# `dump` is joblib.dump, imported in the first cell; the notebook previously
# called it without importing it, which raises NameError on a fresh kernel.
model_path = os.path.join(
    RESULTS_PATH,
    DATA_ID,
    'RF_{}'.format(HS_DATE),
    'RF_best_model_{}_{}.joblib'.format(HS_DATE, DATA_ID),
)
dump(model, model_path)

In [None]:
def _per_output_scores(y_true, y_pred):
    """Return MAE, MSE and R2 computed separately for every output column."""
    return {
        'MAE': mean_absolute_error(y_true, y_pred, multioutput='raw_values'),
        'MSE': mean_squared_error(y_true, y_pred, multioutput='raw_values'),
        'R2': r2_score(y_true, y_pred, multioutput='raw_values'),
    }


train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

# One array of per-output scores for each subset and metric
results = {
    'Train': _per_output_scores(Y_train, train_preds),
    'Test': _per_output_scores(Y_test, test_preds),
}

# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis.
# The stride-3 indexing implies the output columns are ordered cell by cell as
# (Fx, Fy, Fz) -- NOTE(review): confirm against the data-generation notebook.
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy', 'Fz']):
        # Column positions of this axis across all cells: f, f + 3, f + 6, ...
        axis_idx = f + 3 * np.arange(N_CELLS)
        for loss in ['MAE', 'MSE', 'R2']:
            scores = results[subset][loss][axis_idx]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

In [None]:
# Optional visual check, currently disabled: scatter plots of output column 3
# against output column 4, true values vs. predictions, for the full training
# set, its first 100 samples, and the full test set.
# NOTE(review): this cell contains only commented-out code; either re-enable it
# (with axis labels and a title) or delete it before sharing the notebook.

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()