In [1]:
import pandas as pd
import numpy as np
import os
import glob
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.neighbors import KNeighborsRegressor

## Configuration

In [21]:
# Joint analysed in this notebook
JOINT = 'Ankle'

# Force-cell numbers listed for each joint
FORCE_CELLS_PER_JOINT = {
    'Hip': [5, 6],
    'Knee': [3, 4, 7, 8],
    'Ankle': [1, 2],
}
# Force cells of the selected joint
CELLS = FORCE_CELLS_PER_JOINT[JOINT]

# Root folder of the stored results
RESULTS_PATH = '../../../../results'
# Dataset ID: the subfolder of RESULTS_PATH that holds the data arrays and the search results
DATA_ID = '0013_09082021'
# Date tag of the hyperparameter search whose results are analysed
HS_DATE = '24082021'
# Number of folds in cross-validation
CV = 6

# Show every column when a dataframe is displayed
pd.set_option('display.max_columns', None)

print('Model trained with data: ' + DATA_ID)

Model trained with data: 0013_09082021


## Hyperparameter search analysis

In [22]:
# Collect the per-configuration JSON files written by the hyperparameter search.
# glob returns matches in arbitrary, filesystem-dependent order, so the list is sorted
# to keep the row order of the results dataframe reproducible across runs and machines.
search_dir = os.path.join(RESULTS_PATH, DATA_ID, '{}_KNN_{}'.format(JOINT, HS_DATE))
results_files_ls = sorted(glob.glob(os.path.join(search_dir, '{}_KNN_{}_*.json'.format(JOINT, HS_DATE))))

print('Number of results files: {}'.format(len(results_files_ls)))

# Fail early with a clear message: with no files the analysis below has nothing to rank
if not results_files_ls:
    raise FileNotFoundError('No hyperparameter search results found in {}'.format(search_dir))

Number of results files: 100


In [23]:
# Read every results file and build one summary record per hyperparameter configuration;
# the dataframe is assembled once from the list of records
results_ls = []
for path in results_files_ls:
    with open(path) as fh:
        payload = json.load(fh)

    # Configuration ID first, then its hyperparameters prefixed with 'param_'
    record = {'params_ID': payload['id']}
    record.update({'param_' + name: setting for name, setting in payload['parameters'].items()})

    # Each cv_results entry is reduced to its mean and standard deviation
    for metric, fold_values in payload['cv_results'].items():
        record['__'.join((metric, 'mean'))] = np.mean(fold_values)
        record['__'.join((metric, 'std'))] = np.std(fold_values)

    results_ls.append(record)

results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_n_neighbors,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std
0,0N8NZRZX7N,681,0.008977,0.000153,5.763313,0.090639,0.580157,0.140027,76.954131,3.271656,3.678013,2.195534,0.636019,0.017850,0.017917,0.010921,4.362110,0.188090,0.780746,0.220995,35.445451,2.391683,13.880254,3.120001,0.407906,0.009266,0.043460,0.012566,6.178748,0.753928,0.463345,0.298521,90.329114,40.306645,6.725860,3.043590,0.567014,0.057815,0.026811,0.020703,4.555732,0.453765,0.870515,0.537714,38.510064,8.378685,15.966851,10.094073,0.330507,0.090201,0.101536,0.065356
1,12AFJYNL89,461,0.008573,0.000381,5.205909,0.057428,0.572019,0.131067,59.774890,1.869755,3.711457,2.068120,0.716755,0.020698,0.018068,0.009953,4.088441,0.144447,0.682353,0.193021,31.953558,1.832718,12.298614,2.848745,0.464760,0.016502,0.035446,0.013922,5.696984,0.591182,0.474923,0.342854,74.129345,31.397173,5.976824,2.958760,0.636428,0.045816,0.039692,0.025064,4.306324,0.402912,0.765881,0.502001,35.290502,6.790074,14.260819,8.966730,0.377047,0.113725,0.106096,0.062524
2,1NGBPL2NTR,131,0.008741,0.000334,3.584864,0.080731,0.242936,0.076651,25.163625,1.040188,1.055585,0.665379,0.880577,0.011633,0.005193,0.001814,3.370164,0.018829,0.481739,0.160622,22.794858,0.798375,7.974462,2.155235,0.612680,0.028856,0.016128,0.009347,4.221738,0.306861,0.313015,0.190282,35.828071,9.919867,3.609869,3.689250,0.807005,0.055187,0.036401,0.031232,3.765976,0.203875,0.568085,0.381565,27.814745,3.761179,9.963328,6.594258,0.491103,0.141152,0.083628,0.047794
3,1YCME92LUV,71,0.008989,0.000569,3.045363,0.080588,0.097615,0.037866,18.183788,0.960762,0.527184,0.313441,0.913855,0.007048,0.002105,0.001445,2.946435,0.023470,0.379270,0.139950,17.999917,0.513302,5.773402,1.690100,0.690686,0.027605,0.012252,0.005803,3.813036,0.291111,0.205275,0.165341,28.124047,6.299627,3.467780,3.153539,0.844844,0.050932,0.031724,0.029360,3.471392,0.123883,0.503986,0.309142,23.997532,2.735422,7.956265,5.776324,0.560480,0.111892,0.052159,0.012829
4,223RZ7XI4X,181,0.008236,0.000834,3.937763,0.090342,0.359250,0.094230,30.841162,1.576445,1.605480,1.115012,0.853428,0.016857,0.007977,0.004529,3.579800,0.048373,0.543748,0.164504,25.200940,0.932763,9.137865,2.367608,0.574022,0.025785,0.019343,0.013579,4.573528,0.314119,0.364450,0.269606,43.400565,14.193351,4.234504,3.518006,0.772379,0.052673,0.040751,0.030168,3.932731,0.254699,0.636835,0.406568,29.923456,4.246260,11.212230,6.946917,0.454769,0.148765,0.096813,0.060532
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,89CO7AIZK3,221,0.008711,0.000531,4.180807,0.094047,0.426538,0.100341,35.227107,2.015224,2.029673,1.564709,0.832530,0.020197,0.010132,0.006798,3.702412,0.069229,0.582352,0.168303,26.685398,1.068064,9.885464,2.521073,0.550389,0.023617,0.022596,0.015873,4.804603,0.345596,0.403660,0.315030,49.080959,17.241173,4.723263,3.283099,0.746175,0.051479,0.043670,0.030291,4.023174,0.286321,0.668068,0.425190,31.070940,4.638393,11.851586,7.279330,0.436245,0.148967,0.101800,0.063763
96,8QR60TMV95,671,0.008947,0.000333,5.740629,0.089006,0.580937,0.139934,76.213880,3.199593,3.708493,2.191651,0.639502,0.017873,0.018059,0.010901,4.349501,0.186618,0.775812,0.220332,35.291477,2.367205,13.811526,3.110689,0.410421,0.009550,0.043123,0.012677,6.158131,0.747580,0.463404,0.300824,89.622547,39.960005,6.644046,3.100045,0.570124,0.057216,0.027303,0.020806,4.544131,0.451304,0.865936,0.536160,38.360257,8.305995,15.893765,10.043533,0.332754,0.091049,0.101772,0.065182
97,9H98GC7CW0,241,0.008099,0.000474,4.290591,0.095062,0.453482,0.104935,37.376845,2.186350,2.245199,1.771954,0.822310,0.021527,0.011214,0.007851,3.751192,0.078610,0.595514,0.170614,27.294889,1.143328,10.161203,2.573230,0.540483,0.023092,0.024022,0.015941,4.905960,0.361545,0.415363,0.330311,51.676006,18.629371,4.868631,3.119793,0.734340,0.050632,0.044376,0.030514,4.056684,0.301032,0.679733,0.434556,31.533320,4.841272,12.112009,7.464660,0.429364,0.147352,0.103359,0.063737
98,APJQAYJOCR,51,0.009101,0.000862,2.788349,0.079449,0.068986,0.034275,15.550323,0.854796,0.508754,0.236554,0.926345,0.005877,0.002443,0.000775,2.713174,0.032337,0.330672,0.125394,15.618731,0.345836,4.814500,1.452738,0.730372,0.024460,0.011120,0.006747,3.686913,0.220051,0.205982,0.121687,25.831074,5.046507,3.205366,2.534232,0.857314,0.045423,0.029225,0.025986,3.346175,0.115329,0.505719,0.266257,22.483575,2.518105,7.737380,4.944778,0.588830,0.099373,0.038475,0.014104


In [24]:
# Average the Fx and Fy cross-validation means into a single sortable column
# per subset and metric (e.g. 'Valid_R2')
for subset in ('Train', 'Valid'):
    for loss in ('MAE', 'MSE', 'R2'):
        axis_cols = ['_'.join([subset, force, loss]) + '_mean__mean' for force in ('Fx', 'Fy')]
        results_df['_'.join([subset, loss])] = results_df[axis_cols].mean(axis=1)

In [25]:
# Rank the configurations best-first by the combined validation R2
results_df = results_df.sort_values(by='Valid_R2', ascending=False)
results_df

Unnamed: 0,params_ID,param_n_neighbors,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
30,RCOOTPXX3U,11,0.008654,0.000717,1.769110,0.062271,0.093510,0.017838,7.154559,0.584188,0.748275,0.259173,0.966108,0.003249,0.003490,0.001467,1.678717,0.053382,0.146726,0.045002,6.679076,0.340874,1.399202,0.442022,0.880283,0.013456,0.014643,0.003359,3.298697,0.169622,0.210575,0.061475,22.782485,3.994877,2.714121,0.951702,0.874290,0.037282,0.026557,0.023922,2.968040,0.106152,0.432744,0.096158,19.580672,0.985800,6.533414,2.481276,0.633602,0.085905,0.042940,0.018949,1.723914,6.916818,0.923196,3.133368,21.181579,0.753946
21,ODR8869TNA,21,0.008766,0.000321,2.172245,0.078378,0.057677,0.014355,10.101524,0.701694,0.565880,0.324092,0.952148,0.004158,0.002659,0.001872,2.094326,0.041462,0.211983,0.077837,9.860315,0.207112,2.516152,0.759083,0.826292,0.017324,0.013039,0.005672,3.424269,0.179372,0.228492,0.061742,23.118497,4.096520,2.751233,1.059818,0.872766,0.037041,0.026322,0.023184,3.087251,0.082152,0.474347,0.159146,20.214399,1.531845,7.151613,3.270834,0.627014,0.086519,0.040311,0.017659,2.133286,9.980919,0.889220,3.255760,21.666448,0.749890
31,S0ESC660HC,31,0.008358,0.000144,2.431695,0.080226,0.054197,0.023377,12.327380,0.754130,0.543615,0.160782,0.941602,0.004863,0.002593,0.001172,2.362192,0.034206,0.262073,0.098855,12.202881,0.217772,3.488859,0.990711,0.787483,0.019911,0.008958,0.007055,3.532824,0.192641,0.212915,0.081796,23.964377,4.320020,2.760985,1.466034,0.868017,0.039500,0.026360,0.023198,3.190741,0.106767,0.498711,0.198532,20.973819,2.155641,7.433679,3.938252,0.615420,0.090559,0.033135,0.018670,2.396944,12.265130,0.864543,3.361783,22.469098,0.741719
72,TYKN1KOEI1,41,0.008462,0.000869,2.628335,0.079595,0.059141,0.031196,14.072086,0.794691,0.551784,0.113103,0.933342,0.005366,0.002642,0.000707,2.556572,0.033605,0.300035,0.112311,14.060985,0.284381,4.214399,1.202414,0.756414,0.022498,0.009012,0.006879,3.614548,0.203071,0.206993,0.101886,24.890135,4.671621,2.913787,1.946560,0.862860,0.042048,0.027310,0.023930,3.274318,0.116875,0.501592,0.236013,21.696575,2.445102,7.584151,4.491771,0.603175,0.093872,0.030178,0.021949,2.592454,14.066535,0.844878,3.444433,23.293355,0.733018
98,APJQAYJOCR,51,0.009101,0.000862,2.788349,0.079449,0.068986,0.034275,15.550323,0.854796,0.508754,0.236554,0.926345,0.005877,0.002443,0.000775,2.713174,0.032337,0.330672,0.125394,15.618731,0.345836,4.814500,1.452738,0.730372,0.024460,0.011120,0.006747,3.686913,0.220051,0.205982,0.121687,25.831074,5.046507,3.205366,2.534232,0.857314,0.045423,0.029225,0.025986,3.346175,0.115329,0.505719,0.266257,22.483575,2.518105,7.737380,4.944778,0.588830,0.099373,0.038475,0.014104,2.750761,15.584527,0.828358,3.516544,24.157324,0.723072
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
38,TQC9KRH0EY,951,0.009014,0.000177,6.331190,0.128688,0.538365,0.123822,95.758005,5.073448,2.335184,1.675658,0.547471,0.018499,0.011690,0.008118,4.642195,0.212268,0.873304,0.235046,39.094782,2.906663,15.359902,3.410409,0.347372,0.004877,0.048656,0.011981,6.692730,0.907090,0.454475,0.244513,108.198442,48.092309,8.784728,3.203529,0.486066,0.071020,0.027321,0.005793,4.819770,0.495512,0.970792,0.565192,42.000069,9.766202,17.476656,11.243496,0.276007,0.074951,0.100186,0.059006,5.486693,67.426393,0.447421,5.756250,75.099255,0.381037
23,P3R1UW6IB9,961,0.008080,0.000716,6.350216,0.130400,0.536849,0.123224,96.407387,5.145409,2.286381,1.648407,0.544417,0.018505,0.011456,0.007962,4.650284,0.212946,0.875159,0.235181,39.211735,2.926297,15.402128,3.422778,0.345404,0.004795,0.048719,0.011973,6.710245,0.911347,0.454365,0.243636,108.811817,48.309968,8.852351,3.289921,0.483199,0.071304,0.027334,0.005481,4.827739,0.496784,0.972965,0.566236,42.117105,9.810824,17.521573,11.288333,0.274167,0.074494,0.100190,0.058385,5.500250,67.809561,0.444910,5.768992,75.464461,0.378683
69,BEY4I38I3G,971,0.008430,0.000571,6.369235,0.132492,0.535286,0.122461,97.055757,5.223696,2.236540,1.614060,0.541371,0.018461,0.011213,0.007804,4.658473,0.213543,0.877346,0.235490,39.331633,2.943290,15.447840,3.434618,0.343399,0.004754,0.048834,0.011980,6.727380,0.915707,0.454444,0.243120,109.417421,48.529871,8.919052,3.371981,0.480378,0.071610,0.027353,0.005249,4.835163,0.498075,0.975019,0.567398,42.227017,9.856215,17.562825,11.331773,0.272479,0.073867,0.100130,0.057717,5.513854,68.193695,0.442385,5.781272,75.822219,0.376428
14,LH8ML4T2N4,981,0.008167,0.000427,6.388431,0.134517,0.533936,0.121879,97.706375,5.300249,2.194060,1.580930,0.538314,0.018420,0.011006,0.007639,4.666466,0.214133,0.879309,0.235816,39.448379,2.959423,15.492667,3.447013,0.341449,0.004742,0.048952,0.011963,6.744715,0.919621,0.454273,0.242617,110.021030,48.749414,8.986062,3.450180,0.477554,0.071902,0.027419,0.005092,4.842512,0.499204,0.976694,0.568785,42.337176,9.900696,17.602173,11.378860,0.270790,0.073194,0.100011,0.056959,5.527449,68.577377,0.439881,5.793613,76.179103,0.374172


In [26]:
# Take the hyperparameters from the first row of results_df.
# Only columns that START with 'param_' are treated as hyperparameters: matching the prefix
# (instead of the substring anywhere in the name) keeps any other column whose name merely
# contains 'param_' from being forwarded to the model constructor, and stripping only the
# prefix leaves a parameter name that itself contains 'param_' intact.
param_prefix = 'param_'
best_row = results_df.iloc[0]
best_params = {
    col[len(param_prefix):]: best_row[col]
    for col in results_df.columns
    if col.startswith(param_prefix)
}
print('Best parameters: {}'.format(best_params))

Best parameters: {'n_neighbors': 11}


## Best model

In [27]:
# Load the train/test input and target arrays stored for this joint and dataset
data_dir = os.path.join(RESULTS_PATH, DATA_ID, 'data')
X_train = np.load(os.path.join(data_dir, '{}_X_train_{}.npy'.format(JOINT, DATA_ID)))
X_test = np.load(os.path.join(data_dir, '{}_X_test_{}.npy'.format(JOINT, DATA_ID)))
Y_train = np.load(os.path.join(data_dir, '{}_Y_train_{}.npy'.format(JOINT, DATA_ID)))
Y_test = np.load(os.path.join(data_dir, '{}_Y_test_{}.npy'.format(JOINT, DATA_ID)))

In [28]:
# Build the regressor from the selected hyperparameters and fit it on the training data
model = KNeighborsRegressor(n_jobs=-1, **best_params)
model.fit(X_train, Y_train)

# Store the fitted model in the folder of this hyperparameter search.
# dump() is kept as the last expression so the written file list is shown as the cell output.
model_dir = os.path.join(RESULTS_PATH, DATA_ID, '{}_KNN_{}'.format(JOINT, HS_DATE))
model_file = '{}_KNN_best_model_{}_{}.joblib'.format(JOINT, HS_DATE, DATA_ID)
dump(model, os.path.join(model_dir, model_file))

['../../../../results/0013_09082021/Ankle_KNN_24082021/Ankle_KNN_best_model_24082021_0013_09082021.joblib']

In [29]:
# Predict on both subsets with the fitted model
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

# One score per output column (multioutput='raw_values') for every subset and metric
metric_fns = {'MAE': mean_absolute_error, 'MSE': mean_squared_error, 'R2': r2_score}
eval_sets = {'Train': (Y_train, train_preds), 'Test': (Y_test, test_preds)}
results = {
    subset_name: {
        loss_name: metric_fn(y_true, y_pred, multioutput='raw_values')
        for loss_name, metric_fn in metric_fns.items()
    }
    for subset_name, (y_true, y_pred) in eval_sets.items()
}

# Display the score mean and standard deviation of each axis.
# The indexing treats the output columns as interleaved (Fx, Fy) pairs, one pair per
# force cell, so column 2 * c + f is axis f of the c-th cell.
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results[subset][loss][2 * c + f] for c in range(len(CELLS))]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

Train Fx MAE: 1.7668 ± 0.0883
Train Fx MSE: 7.1628 ± 0.8324
Train Fx R2: 0.9664 ± 0.0040
Train Fy MAE: 1.6812 ± 0.1318
Train Fy MSE: 6.6982 ± 1.3221
Train Fy R2: 0.8803 ± 0.0170
Test Fx MAE: 3.8377 ± 0.1204
Test Fx MSE: 28.7426 ± 2.4439
Test Fx R2: 0.8249 ± 0.0026
Test Fy MAE: 3.2209 ± 0.1716
Test Fy MSE: 25.3148 ± 8.8965
Test Fy R2: 0.3795 ± 0.2599


In [30]:
# Disabled exploratory scatter plots of true vs. predicted values (train, first 100 train
# samples, test). Nothing in this cell executes.
# NOTE(review): the column indices 3 and 4 are hard-coded. The evaluation cell only reads
# output columns 0 .. len(CELLS) * 2 - 1 (0..3 for the current joint), so index 4 is
# presumably out of range here - confirm the target layout before re-enabling.

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()