In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor

## Configuration

In [2]:
# Number of force cells in the robotic leg
N_CELLS = 8

# Path where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the training and validation data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0007_19072021'
# Hyperparameters search date
HS_DATE = '19072021'
# Number of folds in cross-validation
CV = 6

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0007_19072021


## Hyperparameters seach analysis

In [33]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_{}_*.json'.format(HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 477


In [34]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std
0,012ZY6CFF3,15,0.10,0.010,0.0001,100,0.455093,0.007254,8.188147,0.078719,3.339821,0.046849,155.811724,2.418438,101.145140,2.343509,0.393600,0.006439,0.139414,0.009066,5.842679,0.121483,1.625441,0.076182,116.556536,6.628633,96.281522,7.849076,0.889247,0.005252,0.038166,0.002910,8.802642,0.510228,3.642327,0.262421,176.332693,14.855311,118.651327,19.083946,0.297903,0.023087,0.149635,0.037186,6.573267,0.958029,1.987674,0.670323,139.161921,61.610775,115.989577,74.570185,0.861858,0.039361,0.044421,0.015460
1,01JWTIP2CK,13,0.05,0.001,0.0001,10000,61.193716,0.802063,5.750351,0.038321,2.364610,0.046404,76.951419,0.924645,49.235046,1.348746,0.670582,0.005235,0.120927,0.005610,2.656379,0.020933,0.620435,0.020509,22.716786,0.527829,14.897265,0.538528,0.976076,0.000586,0.007785,0.000418,7.711804,0.386912,3.174346,0.339966,130.030766,8.974161,82.313566,10.145761,0.445416,0.026265,0.158941,0.048680,3.997086,0.723181,1.144599,0.484141,52.501241,31.211960,42.034105,37.986755,0.946562,0.020183,0.018621,0.007631
2,03KEOA27I8,13,0.05,0.001,0.0010,1000,6.186068,0.103448,5.747449,0.038152,2.363168,0.046778,76.921070,0.938760,49.213528,1.389733,0.670806,0.005424,0.120734,0.005674,2.662813,0.020723,0.625407,0.020381,22.781827,0.577503,14.777465,0.557684,0.975867,0.000623,0.008152,0.000428,7.707151,0.390238,3.172002,0.338642,129.874632,9.032385,82.175658,10.006421,0.445995,0.025992,0.158985,0.048218,4.009785,0.723710,1.142511,0.469375,52.431659,30.505027,41.416118,36.601136,0.946115,0.020200,0.019131,0.008180
3,03P0PU5K2L,5,0.01,0.010,0.0010,100,0.362862,0.009496,8.745866,0.113333,3.478432,0.055000,170.349750,2.772972,106.712218,2.316136,0.331377,0.005050,0.151088,0.008547,7.873469,0.128609,2.000502,0.098507,154.093018,6.515274,114.038516,7.847105,0.846144,0.004913,0.046238,0.003671,9.098224,0.551886,3.708899,0.251073,184.404337,16.526077,122.265193,20.651903,0.263173,0.018710,0.158051,0.035094,8.379544,1.091324,2.346953,0.672896,176.453311,67.973807,134.928472,80.509654,0.816592,0.045451,0.055102,0.014803
4,06LX5K20FK,5,0.01,0.001,0.0001,10000,36.923214,0.297093,8.461554,0.118066,3.339390,0.064742,149.625267,2.912678,89.316155,2.233265,0.386130,0.005383,0.178439,0.010166,7.698590,0.153590,1.912399,0.074013,119.428774,3.991498,69.305583,2.853659,0.869829,0.004371,0.036061,0.002601,8.878085,0.556001,3.588064,0.296149,167.380189,14.423812,104.984592,9.691091,0.305871,0.017473,0.178036,0.043705,8.287264,0.889586,2.378293,0.532713,149.241061,43.385223,102.012344,48.000859,0.835978,0.032144,0.048684,0.017808
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
472,1KCXOMHLPC,5,0.10,0.001,0.0100,10000,36.139055,0.213920,8.555179,0.117605,3.376653,0.062815,155.090564,2.776284,93.064491,1.981745,0.368106,0.004573,0.174105,0.010143,7.818147,0.156958,1.955541,0.075011,128.483077,4.720586,78.888842,3.589586,0.862498,0.004815,0.037988,0.002759,8.940524,0.556670,3.615607,0.288100,171.393330,14.608287,108.633054,11.889673,0.294017,0.017150,0.175083,0.040864,8.370553,0.938064,2.395018,0.564806,156.397243,49.746255,109.329672,56.292769,0.829934,0.035629,0.049832,0.018383
473,1PSSU56B6X,5,0.05,0.010,0.0100,10000,35.432682,0.112956,8.760462,0.109961,3.489426,0.054466,170.700506,2.853882,106.982628,2.459824,0.330248,0.004658,0.152130,0.008978,8.010112,0.156047,2.110692,0.067917,157.421525,7.056983,116.537593,7.396550,0.843828,0.005546,0.040300,0.002956,9.099307,0.545497,3.710239,0.246364,183.993955,15.724337,121.831974,19.102299,0.264659,0.017575,0.158691,0.033832,8.514431,1.071090,2.489858,0.690031,179.333343,66.194053,137.522917,78.523317,0.814111,0.044865,0.050907,0.018989
474,1TA4XGU2DS,13,0.10,0.100,0.0001,10000,23.793388,0.326572,9.961081,0.110627,4.152378,0.039402,255.913345,5.710124,199.355135,11.028547,0.121607,0.003369,0.045764,0.003480,12.564500,0.199214,4.151633,0.155911,501.467138,16.571171,499.693061,21.636939,0.587114,0.007469,0.084792,0.004594,10.089026,0.521946,4.309120,0.245606,260.453829,33.893271,209.781043,64.980678,0.085982,0.029638,0.067760,0.019371,12.986378,1.316806,4.458597,1.162611,519.461185,134.097020,511.820566,187.658954,0.551230,0.061974,0.074103,0.016621
475,1WYDAGJN2A,10,0.05,0.010,0.0010,1000,5.332068,0.295834,8.179846,0.081894,3.338038,0.046532,155.610889,2.334505,100.954219,2.224767,0.394397,0.005036,0.138904,0.009128,5.883027,0.094252,1.638071,0.034157,117.602785,5.432889,97.021560,5.903821,0.888708,0.003911,0.035281,0.002218,8.785186,0.500117,3.641472,0.247331,175.858537,14.755020,118.608665,19.440211,0.301309,0.021935,0.148066,0.034655,6.618467,1.003440,2.007343,0.658533,140.933123,63.276751,117.559816,77.003646,0.860336,0.041236,0.043184,0.016924


In [35]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Valid']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy']]].mean(axis=1)

In [36]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Valid_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_max_depth,param_max_features,param_min_samples_leaf,param_min_samples_split,param_n_estimators,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
357,ZC0ERU89FN,15,0.10,0.001,0.0001,5000,31.309914,0.418633,5.686733,0.041129,2.337464,0.045534,75.901892,0.976679,48.514944,1.353689,0.674522,0.005309,0.119907,0.005418,2.597571,0.018961,0.605305,0.018379,22.379849,0.519082,14.625618,0.552754,0.976379,0.000538,0.007807,0.000417,7.702730,0.384965,3.171069,0.337776,129.929950,8.975045,82.279584,10.194563,0.446035,0.026181,0.158723,0.048517,3.968645,0.721118,1.137518,0.485832,52.222719,31.217248,41.768541,38.016275,0.946801,0.020191,0.018561,0.007595,4.142152,49.140871,0.825450,5.835688,91.076335,0.696418
332,7AJFHDDQIO,15,0.10,0.001,0.0001,1000,6.251608,0.069263,5.688727,0.039740,2.338200,0.046976,75.973941,0.945811,48.547893,1.389744,0.674236,0.005241,0.119925,0.005519,2.609230,0.019514,0.610607,0.017566,22.566711,0.567729,14.740445,0.515973,0.976161,0.000536,0.008003,0.000465,7.699111,0.382052,3.170216,0.339854,129.788621,8.961198,82.219148,10.215789,0.446162,0.026070,0.159217,0.048383,3.976761,0.717522,1.142868,0.469887,52.146783,30.577119,41.468589,36.600865,0.946605,0.020304,0.018833,0.007897,4.148978,49.270326,0.825199,5.837936,90.967702,0.696384
467,174OMFFQIR,15,0.20,0.001,0.0001,1000,6.449018,0.139990,5.688727,0.039740,2.338200,0.046976,75.973941,0.945811,48.547893,1.389744,0.674236,0.005241,0.119925,0.005519,2.609230,0.019514,0.610607,0.017566,22.566711,0.567729,14.740445,0.515973,0.976161,0.000536,0.008003,0.000465,7.699111,0.382052,3.170216,0.339854,129.788621,8.961198,82.219148,10.215789,0.446162,0.026070,0.159217,0.048383,3.976761,0.717522,1.142868,0.469887,52.146783,30.577119,41.468589,36.600865,0.946605,0.020304,0.018833,0.007897,4.148978,49.270326,0.825199,5.837936,90.967702,0.696384
99,U5QZB2LGG0,15,0.10,0.001,0.0010,1000,8.900936,0.080087,5.688727,0.039740,2.338200,0.046976,75.973941,0.945811,48.547893,1.389744,0.674236,0.005241,0.119925,0.005519,2.609230,0.019514,0.610607,0.017566,22.566711,0.567729,14.740445,0.515973,0.976161,0.000536,0.008003,0.000465,7.699111,0.382052,3.170216,0.339854,129.788621,8.961198,82.219148,10.215789,0.446162,0.026070,0.159217,0.048383,3.976761,0.717522,1.142868,0.469887,52.146783,30.577119,41.468589,36.600865,0.946605,0.020304,0.018833,0.007897,4.148978,49.270326,0.825199,5.837936,90.967702,0.696384
354,Z63IPMAQI3,15,0.01,0.001,0.0001,1000,6.179547,0.078096,5.688727,0.039740,2.338200,0.046976,75.973941,0.945811,48.547893,1.389744,0.674236,0.005241,0.119925,0.005519,2.609230,0.019514,0.610607,0.017566,22.566711,0.567729,14.740445,0.515973,0.976161,0.000536,0.008003,0.000465,7.699111,0.382052,3.170216,0.339854,129.788621,8.961198,82.219148,10.215789,0.446162,0.026070,0.159217,0.048383,3.976761,0.717522,1.142868,0.469887,52.146783,30.577119,41.468589,36.600865,0.946605,0.020304,0.018833,0.007897,4.148978,49.270326,0.825199,5.837936,90.967702,0.696384
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
344,YKWF21MQ5Y,5,0.10,0.100,0.0100,1000,2.432660,0.033611,9.962744,0.110301,4.157739,0.040258,256.009166,5.762693,199.762929,11.147563,0.122355,0.003057,0.044646,0.003327,12.590681,0.201451,4.078790,0.167917,501.140429,17.335338,495.593881,22.901054,0.584667,0.007636,0.086440,0.004543,10.087823,0.522244,4.314012,0.245451,260.456162,34.018097,210.056740,65.238849,0.087077,0.029626,0.066950,0.019095,13.009722,1.305414,4.381979,1.142203,518.762314,133.117896,507.196496,185.142762,0.548609,0.062510,0.075965,0.016913,11.276712,378.574797,0.353511,11.548773,389.609238,0.317843
449,QJWKJN4SS8,10,0.05,0.100,0.0100,1000,2.403735,0.023417,9.962744,0.110301,4.157739,0.040258,256.009166,5.762693,199.762929,11.147563,0.122355,0.003057,0.044646,0.003327,12.590681,0.201451,4.078790,0.167917,501.140429,17.335338,495.593881,22.901054,0.584667,0.007636,0.086440,0.004543,10.087823,0.522244,4.314012,0.245451,260.456162,34.018097,210.056740,65.238849,0.087077,0.029626,0.066950,0.019095,13.009722,1.305414,4.381979,1.142203,518.762314,133.117896,507.196496,185.142762,0.548609,0.062510,0.075965,0.016913,11.276712,378.574797,0.353511,11.548773,389.609238,0.317843
203,2YJK2QMG18,5,0.20,0.100,0.0010,1000,2.460429,0.017265,9.962744,0.110301,4.157739,0.040258,256.009166,5.762693,199.762929,11.147563,0.122355,0.003057,0.044646,0.003327,12.590681,0.201451,4.078790,0.167917,501.140429,17.335338,495.593881,22.901054,0.584667,0.007636,0.086440,0.004543,10.087823,0.522244,4.314012,0.245451,260.456162,34.018097,210.056740,65.238849,0.087077,0.029626,0.066950,0.019095,13.009722,1.305414,4.381979,1.142203,518.762314,133.117896,507.196496,185.142762,0.548609,0.062510,0.075965,0.016913,11.276712,378.574797,0.353511,11.548773,389.609238,0.317843
179,JV9WC3J1T0,7,0.05,0.100,0.0100,1000,2.400998,0.041384,9.962744,0.110301,4.157739,0.040258,256.009166,5.762693,199.762929,11.147563,0.122355,0.003057,0.044646,0.003327,12.590681,0.201451,4.078790,0.167917,501.140429,17.335338,495.593881,22.901054,0.584667,0.007636,0.086440,0.004543,10.087823,0.522244,4.314012,0.245451,260.456162,34.018097,210.056740,65.238849,0.087077,0.029626,0.066950,0.019095,13.009722,1.305414,4.381979,1.142203,518.762314,133.117896,507.196496,185.142762,0.548609,0.062510,0.075965,0.016913,11.276712,378.574797,0.353511,11.548773,389.609238,0.317843


In [37]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'max_depth': 15, 'max_features': 0.1, 'min_samples_leaf': 0.001, 'min_samples_split': 0.0001, 'n_estimators': 5000}


## Best model

In [8]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_train_{}.npy'.format(DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'X_test_{}.npy'.format(DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_train_{}.npy'.format(DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', 'Y_test_{}.npy'.format(DATA_ID)))

In [10]:
# Setup the model with the best parameters
model = RandomForestRegressor(**best_params, random_state=0, n_jobs=-1, verbose=1)

model.fit(X_train, Y_train)

# Save the model
dump(model, os.path.join(RESULTS_PATH, DATA_ID, 'RF_{}'.format(HS_DATE), 'RF_best_model_{}_{}.joblib'.format(HS_DATE, DATA_ID))) 

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    1.9s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:    5.0s finished


['../../../results/0003_11042021/RF_12042021/RF_best_model_12042021_0003_11042021.joblib']

In [11]:
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

results = {
    'Train': {
        'MAE': mean_absolute_error(Y_train, train_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_train, train_preds, multioutput='raw_values'),
        'R2': r2_score(Y_train, train_preds, multioutput='raw_values')
    },
    'Test': {
        'MAE': mean_absolute_error(Y_test, test_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_test, test_preds, multioutput='raw_values'),
        'R2': r2_score(Y_test, test_preds, multioutput='raw_values')
    }       
    
}

# # Display the score for each axis of each force cell
# for subset in ['Train', 'Test']:
#     for f, force in enumerate(['Fx', 'Fy', 'Fz']):
#         for c in range(N_CELLS):
#             for loss in ['MAE', 'MSE', 'R2']:
#                 scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
#                 print('{} {}{}{} {}: {:.4f}'.format(subset, force[0], c + 1, force[-1], loss, scores[c]))
# print('\n')

# Display the score mean and standard deviation of each axis
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy', 'Fz']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results[subset][loss][i + f] for i in range(0, N_CELLS * 3, 3)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.2s
[Parallel(n_jobs=8)]: Done 100 out of 100 | elapsed:    0.5s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Done 100 out of 100 | elapsed:    0.1s finished


Train Fx MAE: 8.6461 ± 2.4048
Train Fx MSE: 151.6214 ± 77.8937
Train Fx R2: 0.5655 ± 0.0612
Train Fy MAE: 8.7525 ± 6.0748
Train Fy MSE: 238.7282 ± 265.7849
Train Fy R2: 0.4389 ± 0.1117
Train Fz MAE: 11.2891 ± 3.9985
Train Fz MSE: 300.2048 ± 246.7115
Train Fz R2: 0.5458 ± 0.0617
Test Fx MAE: 13.0032 ± 5.0952
Test Fx MSE: 311.5999 ± 249.7830
Test Fx R2: 0.3473 ± 0.2863
Test Fy MAE: 10.7406 ± 7.7562
Test Fy MSE: 378.7206 ± 481.6318
Test Fy R2: 0.3404 ± 0.1027
Test Fz MAE: 19.0024 ± 6.5255
Test Fz MSE: 687.2684 ± 593.5296
Test Fz R2: 0.3611 ± 0.2479


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()