In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.neighbors import KNeighborsRegressor

## Configuration

In [2]:
# Joint whose force cells are modelled in this notebook.
JOINT = 'Hip'

# Force-cell indices attached to each joint; only the entry for JOINT is used.
FORCE_CELLS_PER_JOINT = dict(
    Hip=[5, 6],
    Knee=[3, 4, 7, 8],
    Ankle=[1, 2],
)
CELLS = FORCE_CELLS_PER_JOINT[JOINT]

# Root folder holding the prepared data and the hyperparameter-search results.
RESULTS_PATH = '../../../../results'
# Identifier of the training/validation data this notebook works on (a sub-folder of RESULTS_PATH).
DATA_ID = '0010_09082021'
# Date tag of the hyperparameter search run to analyse.
HS_DATE = '23082021'
# Number of cross-validation folds used by the search.
CV = 6

# Show every column of the (wide) results tables.
pd.set_option('display.max_columns', None)

print('Model trained with data: {}'.format(DATA_ID))

Model trained with data: 0010_09082021


## Hyperparameter search analysis

In [3]:
# Folder holding one JSON file per hyperparameter configuration evaluated by
# the KNN search tagged HS_DATE.
hs_dir = os.path.join(RESULTS_PATH, DATA_ID, '{}_KNN_{}'.format(JOINT, HS_DATE))
# glob returns paths in arbitrary (filesystem-dependent) order; sort them so the
# results table, and how ties are ordered when ranking it, are reproducible.
results_files_ls = sorted(glob.glob(os.path.join(hs_dir, '{}_KNN_{}_*.json'.format(JOINT, HS_DATE))))
# Fail here with a clear message rather than with a KeyError in a later cell.
assert results_files_ls, 'No hyperparameter search results found in {}'.format(hs_dir)

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 116


In [4]:
def _summarise_search_run(run):
    """Flatten one hyperparameter-search record into a single table row.

    Parameters
    ----------
    run : dict
        Parsed JSON record with keys 'id', 'parameters' (name -> value) and
        'cv_results' (metric -> sequence of values, presumably one per CV fold).

    Returns
    -------
    dict
        'params_ID', one 'param_<name>' entry per hyperparameter, and for each
        metric '<metric>__mean' and '<metric>__std' (np.std, i.e. ddof=0).
    """
    row = {'params_ID': run['id']}
    row.update({'param_' + name: setting for name, setting in run['parameters'].items()})
    for metric, fold_values in run['cv_results'].items():
        row[metric + '__mean'] = np.mean(fold_values)
        row[metric + '__std'] = np.std(fold_values)
    return row


# Read every search record and tabulate them, one row per configuration.
summaries = []
for path in results_files_ls:
    with open(path) as handle:
        summaries.append(_summarise_search_run(json.load(handle)))

results_df = pd.DataFrame(summaries)
results_df

Unnamed: 0,params_ID,param_n_neighbors,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std
0,0NGIM8GBN7,481,0.012298,0.000706,7.631921,0.089827,0.323936,0.116494,116.756978,3.941339,7.005488,3.048815,0.536286,0.018818,0.081301,0.019485,8.610285,0.232924,5.292244,0.235765,260.949519,16.597412,234.714136,16.642524,0.388402,0.009955,0.034662,0.005275,8.221936,0.455411,0.452351,0.274091,134.895613,22.954946,16.001017,11.803391,0.430769,0.104825,0.138931,0.052912,9.333745,1.074878,5.837263,1.065305,299.951461,127.714966,271.331645,127.841786,0.304347,0.065346,0.055388,0.028827
1,0UJM1B06QQ,581,0.011377,0.000666,7.744493,0.092514,0.339233,0.114443,119.691331,4.125096,6.643466,3.080700,0.524129,0.019494,0.085489,0.019613,8.776222,0.269754,5.397814,0.269960,273.557768,19.605751,246.369957,19.624558,0.362762,0.008573,0.032477,0.005316,8.286545,0.440949,0.436125,0.290927,136.569091,22.652112,15.338761,10.928493,0.423940,0.102331,0.137251,0.056005,9.443991,1.136775,5.902998,1.120347,310.201775,138.862206,280.814377,138.890260,0.287277,0.068610,0.057060,0.028756
2,176LOEAZDW,951,0.012243,0.000635,8.056155,0.096465,0.352196,0.108320,127.812678,4.604760,6.087352,3.216755,0.490919,0.020761,0.095299,0.020201,9.189741,0.331247,5.646504,0.325022,305.796767,26.742014,275.980255,26.696239,0.294894,0.007819,0.029360,0.007085,8.488665,0.401062,0.381738,0.326391,141.802817,20.885660,13.902881,8.461630,0.404247,0.091613,0.130846,0.062388,9.721824,1.215426,6.054895,1.184878,336.652496,163.501921,305.077133,163.535513,0.239645,0.068781,0.061033,0.038305
3,1D5FG2ZXNC,721,0.012509,0.000467,7.877887,0.095534,0.349210,0.111718,123.166354,4.337863,6.326833,3.160431,0.509841,0.020086,0.090004,0.019941,8.960733,0.304221,5.510677,0.301780,287.721209,22.808158,259.419611,22.799779,0.333429,0.007894,0.030579,0.005939,8.368593,0.422978,0.414346,0.305347,138.725661,22.066101,14.711911,9.891580,0.415519,0.098663,0.135151,0.059013,9.565424,1.186188,5.968916,1.160742,321.869453,150.390606,291.541300,150.342562,0.266771,0.070482,0.058229,0.031956
4,1L6V53A015,61,0.016866,0.004130,6.367482,0.084169,0.193906,0.098647,86.073695,2.917122,6.900960,2.276327,0.659807,0.014225,0.053067,0.014161,7.045552,0.098150,4.323134,0.124234,171.163608,9.584111,152.966757,10.003723,0.587221,0.011529,0.035106,0.009882,7.803347,0.527365,0.488523,0.230941,125.683259,23.240990,19.061826,14.059686,0.470644,0.111906,0.137151,0.049360,8.737465,0.536656,5.504521,0.612198,250.289269,56.779805,226.274963,58.235190,0.385563,0.112384,0.075193,0.060829
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
111,PTHAVAM31U,21,0.012803,0.000799,5.570407,0.083823,0.218491,0.083722,69.725288,2.487284,4.456260,2.011969,0.723349,0.012052,0.047554,0.011520,6.175762,0.079940,3.791596,0.107035,139.385431,9.436864,124.619418,9.930252,0.664664,0.011447,0.027914,0.012322,7.848483,0.604193,0.469251,0.226034,129.084218,26.245222,19.772049,13.905685,0.456693,0.119465,0.139050,0.052025,8.808743,0.502071,5.568716,0.592213,259.314917,60.616437,235.374756,62.428423,0.372694,0.126983,0.078309,0.076105
112,PV839T0O6K,881,0.014544,0.003367,8.006283,0.096713,0.351833,0.108732,126.505894,4.541459,6.134817,3.209783,0.496221,0.020640,0.093891,0.020140,9.125211,0.326530,5.607898,0.321229,300.853468,25.522597,271.472057,25.484242,0.305659,0.007759,0.029407,0.006912,8.453373,0.406294,0.392020,0.319127,140.888233,21.241421,14.109684,8.853653,0.407506,0.093722,0.132345,0.061417,9.676754,1.212095,6.030217,1.183231,332.653467,160.250942,301.433413,160.252456,0.247379,0.069430,0.060161,0.036642
113,QNKYFOQWB4,801,0.013472,0.000805,7.945193,0.096284,0.351267,0.109998,124.914857,4.440344,6.218205,3.203628,0.502701,0.020339,0.092073,0.020119,9.047430,0.316972,5.562139,0.313052,294.671587,24.240672,265.811743,24.214846,0.318889,0.007726,0.029785,0.006435,8.412435,0.415386,0.402820,0.311137,139.844821,21.690423,14.373032,9.352649,0.411338,0.096294,0.133713,0.060180,9.623471,1.204221,6.000668,1.176124,327.606184,155.773990,296.811839,155.738088,0.256619,0.070152,0.059051,0.034126
114,QP3WNIL0YX,621,0.012377,0.001351,7.785499,0.093568,0.343696,0.113896,120.763202,4.189835,6.527899,3.100879,0.519705,0.019654,0.086953,0.019711,8.834447,0.281036,5.434240,0.280487,277.983910,20.622644,250.456274,20.633040,0.353694,0.008200,0.031783,0.005278,8.310968,0.435895,0.430126,0.295551,137.211527,22.500582,15.153938,10.644989,0.421370,0.101432,0.136742,0.056948,9.481574,1.154305,5.923679,1.135060,313.828610,142.561196,284.155210,142.565953,0.280979,0.069595,0.057779,0.029529


In [5]:
# Average the Fx and Fy mean scores into one column per subset and metric
# (e.g. 'Valid_R2'), so configurations can be ranked by a single number.
for subset in ('Train', 'Valid'):
    for loss in ('MAE', 'MSE', 'R2'):
        axis_columns = ['{}_{}_{}_mean__mean'.format(subset, axis, loss) for axis in ('Fx', 'Fy')]
        results_df[subset + '_' + loss] = results_df[axis_columns].mean(axis=1)

In [6]:
# Rank configurations from highest to lowest validation R2 (averaged over Fx
# and Fy), so the best one ends up in the first row.
results_df = results_df.sort_values(by='Valid_R2', ascending=False)
results_df

Unnamed: 0,params_ID,param_n_neighbors,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
42,R1ZGCC70L5,51,0.013155,0.000671,6.244074,0.084368,0.198903,0.095830,83.402803,2.896812,6.515418,2.221822,0.670199,0.013988,0.052110,0.013776,6.909422,0.094071,4.240911,0.121693,165.844029,9.672456,148.248843,10.123054,0.600503,0.011856,0.033542,0.010665,7.795535,0.540788,0.480641,0.232478,125.595337,23.536209,19.110772,13.881104,0.471063,0.113021,0.136525,0.049379,8.739591,0.518570,5.508388,0.601372,250.607355,56.076465,226.667210,57.697858,0.385237,0.115243,0.076153,0.063100,6.576748,124.623416,0.635351,8.267563,188.101346,0.428150
4,1L6V53A015,61,0.016866,0.004130,6.367482,0.084169,0.193906,0.098647,86.073695,2.917122,6.900960,2.276327,0.659807,0.014225,0.053067,0.014161,7.045552,0.098150,4.323134,0.124234,171.163608,9.584111,152.966757,10.003723,0.587221,0.011529,0.035106,0.009882,7.803347,0.527365,0.488523,0.230941,125.683259,23.240990,19.061826,14.059686,0.470644,0.111906,0.137151,0.049360,8.737465,0.536656,5.504521,0.612198,250.289269,56.779805,226.274963,58.235190,0.385563,0.112384,0.075193,0.060829,6.706517,128.618651,0.623514,8.270406,187.986264,0.428103
48,SOC8R3T9MS,41,0.011399,0.000589,6.090744,0.085616,0.205293,0.095867,80.161670,2.894325,5.958956,2.219906,0.682727,0.013689,0.051312,0.013550,6.743438,0.088373,4.142374,0.116746,159.666726,9.582277,142.781259,10.030936,0.616053,0.012517,0.031578,0.011024,7.792752,0.558626,0.476998,0.234727,125.833674,24.091908,19.427955,13.900027,0.470297,0.113987,0.136860,0.050006,8.749163,0.505424,5.521323,0.593742,252.177551,56.268885,228.329493,57.987805,0.383958,0.117617,0.076592,0.066376,6.417091,119.914198,0.649390,8.270957,189.005612,0.427128
5,1O2V27FD9C,71,0.013950,0.000843,6.469419,0.081144,0.191192,0.100287,88.368176,2.874597,7.184358,2.319692,0.650826,0.014279,0.054075,0.014672,7.158350,0.098333,4.391735,0.122635,176.102228,9.238774,157.388552,9.619451,0.575362,0.011549,0.035996,0.009146,7.813988,0.517714,0.496936,0.233011,125.869718,22.881453,18.939129,14.220862,0.469715,0.111160,0.138130,0.048936,8.738725,0.553064,5.500240,0.621867,250.611263,57.495579,226.477899,58.828497,0.384217,0.110810,0.074055,0.059404,6.813885,132.235202,0.613094,8.276356,188.240490,0.426966
93,G7DN6WPBRL,81,0.016491,0.002760,6.553433,0.081275,0.189257,0.102491,90.297778,2.919099,7.421731,2.357342,0.643269,0.014684,0.054919,0.014993,7.254554,0.103316,4.450507,0.124468,180.554591,9.413000,161.388368,9.740113,0.564862,0.011285,0.036658,0.008780,7.820986,0.507184,0.502854,0.229287,125.907920,22.624237,18.591937,14.383992,0.469270,0.110164,0.138652,0.048941,8.746878,0.567775,5.501736,0.630520,251.162410,58.781739,226.887864,60.026984,0.382585,0.108760,0.072927,0.057890,6.903994,135.426185,0.604066,8.283932,188.535165,0.425928
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
33,ISUXQBJXGS,971,0.011513,0.000924,8.069857,0.096579,0.351894,0.107828,128.176688,4.626127,6.088619,3.214140,0.489454,0.020849,0.095631,0.020212,9.207577,0.332753,5.657242,0.326174,307.121985,27.050756,277.186373,27.003794,0.291974,0.007975,0.029377,0.007178,8.498112,0.399764,0.378953,0.328391,142.047324,20.794756,13.848505,8.380947,0.403386,0.091023,0.130460,0.062646,9.734147,1.216168,6.061662,1.185388,337.720576,164.330690,306.046548,164.376813,0.237522,0.068538,0.061251,0.038843,8.638717,217.649337,0.390714,9.116129,239.883950,0.320454
37,L4HVKVNJS6,981,0.012465,0.001190,8.076619,0.096429,0.351906,0.107763,128.356720,4.635914,6.086421,3.216899,0.488729,0.020859,0.095808,0.020215,9.216548,0.333248,5.662703,0.326530,307.805680,27.209791,277.811167,27.162584,0.290498,0.008012,0.029352,0.007249,8.502636,0.398752,0.377305,0.329283,142.168211,20.736063,13.816766,8.336409,0.402968,0.090696,0.130236,0.062740,9.739971,1.217161,6.064825,1.186315,338.250294,164.773217,306.529854,164.825547,0.236513,0.068472,0.061345,0.039155,8.646583,218.081200,0.389614,9.121304,240.209253,0.319741
53,U7E9G6L4YA,990,0.009744,0.000756,8.082682,0.096385,0.351900,0.107479,128.522896,4.645562,6.088136,3.215225,0.488062,0.020877,0.095956,0.020211,9.224402,0.333705,5.667380,0.326788,308.409758,27.341372,278.362332,27.294011,0.289184,0.008045,0.029341,0.007306,8.506659,0.398004,0.376184,0.330279,142.274057,20.687223,13.786661,8.299108,0.402606,0.090382,0.130024,0.062868,9.745780,1.217325,6.068206,1.186386,338.734824,165.111112,306.970585,165.166606,0.235538,0.068357,0.061416,0.039348,8.653542,218.466327,0.388623,9.126220,240.504440,0.319072
103,MVFY90M8EC,991,0.012117,0.000659,8.083378,0.096399,0.351914,0.107482,128.541130,4.647067,6.088071,3.216334,0.487988,0.020884,0.095973,0.020217,9.225244,0.333711,5.667893,0.326791,308.472693,27.356500,278.419671,27.309536,0.289046,0.008048,0.029341,0.007319,8.507038,0.397876,0.376029,0.330257,142.285428,20.679673,13.782952,8.296945,0.402570,0.090344,0.129995,0.062875,9.746452,1.217269,6.068567,1.186369,338.792284,165.145571,307.022343,165.202151,0.235413,0.068341,0.061413,0.039368,8.654311,218.506911,0.388517,9.126745,240.538856,0.318992


In [7]:
# The first row of the ranked table is the best configuration; its 'param_*'
# columns hold the hyperparameters to pass to the estimator.
best_run = results_df.iloc[0]
best_params = {}
for col in results_df.columns:
    # Match the prefix only; strip it once (str.replace would strip every occurrence).
    if not col.startswith('param_'):
        continue
    value = best_run[col]
    # Convert numpy scalars (e.g. np.int64) to plain Python values so they print
    # cleanly and reach the estimator as ordinary ints/floats.
    if isinstance(value, np.generic):
        value = value.item()
    # A parameter missing from this run's JSON shows up as NaN once the records
    # are merged into one DataFrame; leave it to the estimator's default.
    if isinstance(value, float) and np.isnan(value):
        continue
    best_params[col[len('param_'):]] = value
print('Best parameters: {}'.format(best_params))

Best parameters: {'n_neighbors': 51}


## Best model

In [8]:
def load_split(name):
    """Load one saved array of the data split for the configured JOINT/DATA_ID.

    ``name`` is e.g. 'X_train', 'X_test', 'Y_train' or 'Y_test'; the file read is
    RESULTS_PATH/DATA_ID/data/<JOINT>_<name>_<DATA_ID>.npy.
    """
    return np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_{}_{}.npy'.format(JOINT, name, DATA_ID)))


# Load data
X_train = load_split('X_train')
X_test = load_split('X_test')
Y_train = load_split('Y_train')
Y_test = load_split('Y_test')

# Features and targets must be row-aligned within each split.
assert len(X_train) == len(Y_train), 'X_train / Y_train row count mismatch'
assert len(X_test) == len(Y_test), 'X_test / Y_test row count mismatch'

In [9]:
# Refit KNN on the full training set using the best hyperparameters found above.
model = KNeighborsRegressor(n_jobs=-1, **best_params)
model.fit(X_train, Y_train)

# Persist the fitted model next to the search results it was selected from;
# dump() returns the list of written paths, which is shown as the cell output.
model_dir = os.path.join(RESULTS_PATH, DATA_ID, '{}_KNN_{}'.format(JOINT, HS_DATE))
model_file = '{}_KNN_best_model_{}_{}.joblib'.format(JOINT, HS_DATE, DATA_ID)
dump(model, os.path.join(model_dir, model_file))

['../../../../results/0010_09082021/Hip_KNN_23082021/Hip_KNN_best_model_23082021_0010_09082021.joblib']

In [10]:
# Force axes predicted for each force cell, in the (interleaved) order the
# evaluation below assumes for the output columns: Fx, Fy, Fx, Fy, ...
FORCES = ['Fx', 'Fy']
# Regression metrics computed per output column.
METRICS = {
    'MAE': mean_absolute_error,
    'MSE': mean_squared_error,
    'R2': r2_score,
}


def per_output_scores(y_true, y_pred):
    """Return {metric name: array of per-output-column scores} for one subset."""
    return {name: metric(y_true, y_pred, multioutput='raw_values') for name, metric in METRICS.items()}


train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

results = {
    'Train': per_output_scores(Y_train, train_preds),
    'Test': per_output_scores(Y_test, test_preds),
}

n_axes = len(FORCES)
n_outputs = len(CELLS) * n_axes
assert Y_train.shape[1] >= n_outputs and Y_test.shape[1] >= n_outputs, \
    'Expected at least {} output columns ({} cells x {} axes)'.format(n_outputs, len(CELLS), n_axes)

# Display, for each axis, the mean and standard deviation of each score across
# the force cells (columns f, f + n_axes, f + 2 * n_axes, ...).
for subset in ['Train', 'Test']:
    for f, force in enumerate(FORCES):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = results[subset][loss][f:n_outputs:n_axes]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

Train Fx MAE: 6.2814 ± 0.1982
Train Fx MSE: 84.1663 ± 6.6244
Train Fx R2: 0.6681 ± 0.0526
Train Fy MAE: 6.9751 ± 4.3101
Train Fy MSE: 167.9050 ± 150.3781
Train Fy R2: 0.5994 ± 0.0310
Test Fx MAE: 9.2759 ± 2.6061
Test Fx MSE: 179.0654 ± 94.2889
Test Fx R2: 0.3639 ± 0.3463
Test Fy MAE: 13.0549 ± 9.1347
Test Fy MSE: 875.8558 ± 833.5380
Test Fy R2: -0.0848 ± 0.3856


In [None]:
# Disabled diagnostic plots: true vs. predicted values of output columns 3 and 4,
# for the whole training set, its first 100 samples, and the test set.
# NOTE(review): the evaluation cell reads only len(CELLS) * 2 outputs (4 for
# JOINT = 'Hip'), so column index 4 is presumably out of range here; the indices
# look like they come from a layout with more outputs (e.g. an Fx/Fy/Fz or Knee
# configuration) — TODO confirm Y's column layout before re-enabling.
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()