In [1]:
import pandas as pd
import numpy as np
import os
import glob
import time
import json
from joblib import dump, load
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.multioutput import MultiOutputRegressor
from sklearn.svm import SVR

## Configuration

In [2]:
JOINT = 'Ankle'
FORCE_CELLS_PER_JOINT = {
    'Hip': [5, 6],
    'Knee': [3, 4, 7, 8],
    'Ankle': [1, 2]
}

CELLS = FORCE_CELLS_PER_JOINT[JOINT]

# Path where the results are stored
RESULTS_PATH = '../../../../results'
# ID of the training and validation data resulting from this notebook, stored in RESULTS_PATH
DATA_ID = '0013_09082021'
# Hyperparameters search date
HS_DATE = '24082021'
# Number of folds in cross-validation
CV = 4

print('Model trained with data: ' + DATA_ID)

pd.set_option('display.max_columns', None)

Model trained with data: 0013_09082021


## Hyperparameters seach analysis

In [3]:
results_files_ls = glob.glob(os.path.join(RESULTS_PATH, DATA_ID, '{}_SVM_{}'.format(JOINT, HS_DATE), '{}_SVM_{}_*.json'.format(JOINT, HS_DATE)))

print('Number of results files: {}'.format(len(results_files_ls)))

Number of results files: 144


In [4]:
# Load all the results and generates a pandas dataframe
results_ls = []
for results_file in results_files_ls:
    with open(results_file) as json_file:
        results_dict = json.load(json_file)
        
    dict_aux = {}
    dict_aux['params_ID'] = results_dict['id']
    for key, value in results_dict['parameters'].items():
        dict_aux['param_' + key] = value
    for key, value in results_dict['cv_results'].items():
        dict_aux['__'.join([key, 'mean'])] = np.mean(value)
        dict_aux['__'.join([key, 'std'])] = np.std(value)

    results_ls.append(dict_aux)
        
results_df = pd.DataFrame(results_ls)
results_df

Unnamed: 0,params_ID,param_C,param_epsilon,param_kernel,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std
0,TSLVNCM7RS,1.0,0.5,rbf,0.996294,0.055279,3.386590,0.114827,0.146111,0.074839,22.532091,1.283803,1.080087,0.584988,0.893270,0.008854,0.005280,0.002280,3.553511,0.128985,0.551105,0.173759,29.846860,2.645085,12.887297,3.638723,0.510298,0.010628,0.058891,0.021108,3.819991,0.346002,0.265561,0.122248,28.908412,7.607216,3.088718,3.937665,0.846829,0.034808,0.027550,0.022298,3.771095,0.456560,0.738970,0.308599,31.962138,8.753185,15.392481,8.134586,0.469371,0.049870,0.072840,0.031864
1,393KBF7IGF,0.5,0.1,poly,1.060976,0.058176,3.981570,0.160135,0.132843,0.075224,36.730401,2.840776,8.964961,2.873588,0.826675,0.010448,0.041343,0.008153,4.103620,0.166457,0.677114,0.174563,35.679083,2.599925,13.848798,3.637324,0.403502,0.008446,0.040808,0.023103,4.245052,0.368378,0.414775,0.218675,39.592594,8.812555,11.793263,8.081042,0.796199,0.034064,0.047422,0.024182,4.298964,0.609551,0.730103,0.518636,38.209615,11.157348,15.229880,10.454640,0.348352,0.058507,0.062106,0.027997
2,AZDJXBOCJV,1.2,0.1,poly,1.272165,0.055735,3.913371,0.152868,0.126459,0.078646,35.382150,2.749581,8.373452,2.817962,0.833060,0.009823,0.038564,0.008202,4.012581,0.159212,0.636896,0.161892,33.848003,2.361307,12.560449,3.322324,0.430355,0.007547,0.027742,0.022006,4.209806,0.355301,0.436202,0.204247,38.562173,8.455328,11.430493,8.157139,0.801480,0.033018,0.044633,0.026050,4.231958,0.578317,0.665158,0.478379,36.507068,10.433247,13.438912,9.367923,0.366841,0.052669,0.037109,0.029240
3,5URBXMLNF2,0.3,0.5,rbf,0.969620,0.036925,3.946264,0.131589,0.132960,0.085581,31.836742,2.061626,2.092429,1.218479,0.849417,0.011117,0.008257,0.005990,3.802608,0.139908,0.564173,0.176819,32.889562,2.980871,14.291355,4.025037,0.460882,0.009744,0.066414,0.022441,4.317461,0.365414,0.289824,0.178661,38.969598,11.812110,3.525764,5.022282,0.798397,0.035046,0.028079,0.028715,4.013064,0.533947,0.786515,0.325396,34.832543,10.166836,16.902491,8.606361,0.424367,0.061247,0.081850,0.027580
4,T9VTZWRN67,0.8,0.5,poly,1.120822,0.096064,3.943732,0.156295,0.129592,0.077735,35.892693,2.793692,8.555030,2.855021,0.830643,0.010144,0.039409,0.008269,4.054106,0.160773,0.654549,0.167416,34.582144,2.448645,13.128804,3.401465,0.419869,0.008169,0.034030,0.021662,4.222481,0.361900,0.427006,0.214244,38.952145,8.571265,11.502236,8.123368,0.799327,0.033654,0.045837,0.025032,4.261117,0.583578,0.693957,0.495923,37.162453,10.704475,14.212231,9.868175,0.360258,0.056263,0.048795,0.025942
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
139,OH890VQUQE,0.3,0.3,rbf,0.994358,0.027152,3.943619,0.132706,0.134431,0.086018,31.824986,2.079249,2.089660,1.184147,0.849488,0.010990,0.008103,0.006039,3.799082,0.138102,0.565604,0.176059,32.872670,2.948845,14.276782,3.997734,0.461060,0.010335,0.066299,0.022375,4.310389,0.366934,0.284310,0.182211,38.860395,11.782741,3.524092,4.946544,0.799111,0.034697,0.027664,0.027905,4.012170,0.537411,0.789296,0.326780,34.854506,10.239521,16.929197,8.655644,0.424327,0.062471,0.082085,0.028435
140,KDNKQZ36AZ,1.2,0.9,sigmoid,1.604540,0.042314,36.892984,6.169206,1.185700,0.188378,6500.746805,1741.371501,98.240646,73.806690,-29.425439,7.576810,0.729738,0.251144,33.861670,5.896230,0.481482,0.173865,5447.472957,1517.387934,194.320846,25.282766,-105.860540,32.622664,37.376505,8.879396,36.712994,5.864146,1.025450,0.454423,6299.832946,1731.319606,149.434683,103.506644,-32.228455,7.561313,3.018109,3.736385,33.557259,5.427909,0.796837,0.188851,5309.605792,1640.817794,201.469502,73.044067,-107.428083,19.783788,38.037335,20.209822
141,2CHCOZDFOL,0.2,0.7,rbf,0.923356,0.039763,4.229178,0.135392,0.172434,0.084644,37.422620,2.372574,1.962094,1.190033,0.823055,0.011654,0.008710,0.004572,3.921610,0.147887,0.579966,0.181608,34.150131,3.028852,14.788646,4.094895,0.439778,0.008860,0.068027,0.021657,4.584255,0.390925,0.326735,0.169516,45.122834,14.836235,4.334519,5.516209,0.769694,0.035501,0.029096,0.026689,4.132349,0.564331,0.819158,0.336109,36.195502,10.727548,17.624365,8.828883,0.403057,0.062734,0.086097,0.027230
142,IOA1XAGLSI,1.0,0.7,linear,0.820938,0.177420,4.330342,0.164222,0.299513,0.100499,39.647398,3.051886,1.258679,0.887936,0.812882,0.009156,0.006339,0.002815,4.691817,0.236446,0.977852,0.269559,46.095778,4.624127,23.775436,5.601880,0.267533,0.016373,0.164514,0.025236,4.465091,0.462471,0.376316,0.197953,41.813500,10.028578,3.138403,1.663634,0.783235,0.035061,0.020168,0.014749,4.954452,0.696399,1.114591,0.669386,51.161395,16.929837,27.792064,18.018402,0.191273,0.087348,0.193768,0.104953


In [5]:
# Sum up the scores by force axis in only one sortable score
for subset in ['Train', 'Valid']:
    for loss in ['MAE', 'MSE', 'R2']:
        results_df[subset + '_' + loss] = results_df[[subset + '_' + force + '_' + loss + '_mean__mean' for force in ['Fx', 'Fy']]].mean(axis=1)

In [6]:
# Sort the dataframe by the most relevant score
results_df = results_df.sort_values(['Valid_R2'], ascending=False)
results_df

Unnamed: 0,params_ID,param_C,param_epsilon,param_kernel,fit_time__mean,fit_time__std,Train_Fx_MAE_mean__mean,Train_Fx_MAE_mean__std,Train_Fx_MAE_std__mean,Train_Fx_MAE_std__std,Train_Fx_MSE_mean__mean,Train_Fx_MSE_mean__std,Train_Fx_MSE_std__mean,Train_Fx_MSE_std__std,Train_Fx_R2_mean__mean,Train_Fx_R2_mean__std,Train_Fx_R2_std__mean,Train_Fx_R2_std__std,Train_Fy_MAE_mean__mean,Train_Fy_MAE_mean__std,Train_Fy_MAE_std__mean,Train_Fy_MAE_std__std,Train_Fy_MSE_mean__mean,Train_Fy_MSE_mean__std,Train_Fy_MSE_std__mean,Train_Fy_MSE_std__std,Train_Fy_R2_mean__mean,Train_Fy_R2_mean__std,Train_Fy_R2_std__mean,Train_Fy_R2_std__std,Valid_Fx_MAE_mean__mean,Valid_Fx_MAE_mean__std,Valid_Fx_MAE_std__mean,Valid_Fx_MAE_std__std,Valid_Fx_MSE_mean__mean,Valid_Fx_MSE_mean__std,Valid_Fx_MSE_std__mean,Valid_Fx_MSE_std__std,Valid_Fx_R2_mean__mean,Valid_Fx_R2_mean__std,Valid_Fx_R2_std__mean,Valid_Fx_R2_std__std,Valid_Fy_MAE_mean__mean,Valid_Fy_MAE_mean__std,Valid_Fy_MAE_std__mean,Valid_Fy_MAE_std__std,Valid_Fy_MSE_mean__mean,Valid_Fy_MSE_mean__std,Valid_Fy_MSE_std__mean,Valid_Fy_MSE_std__std,Valid_Fy_R2_mean__mean,Valid_Fy_R2_mean__std,Valid_Fy_R2_std__mean,Valid_Fy_R2_std__std,Train_MAE,Train_MSE,Train_R2,Valid_MAE,Valid_MSE,Valid_R2
60,4VW17ZKI1P,1.2,0.1,rbf,1.043625,0.015231,3.327384,0.111840,0.159291,0.074409,21.709501,1.232419,1.246225,0.525311,0.897115,0.009074,0.006094,0.002259,3.521866,0.126625,0.550438,0.174334,29.469280,2.552392,12.675677,3.580447,0.516152,0.010693,0.057279,0.020893,3.778255,0.339998,0.253389,0.120709,28.071217,7.302022,2.969409,3.561791,0.851332,0.033634,0.026541,0.020136,3.745161,0.450274,0.736218,0.310507,31.628080,8.680269,15.189647,8.132592,0.474765,0.048770,0.070960,0.033142,3.424625,25.589391,0.706633,3.761708,29.849648,0.663049
81,S0GOBUB2X4,1.2,0.5,rbf,0.920991,0.008569,3.330336,0.112888,0.157757,0.073645,21.698173,1.238726,1.208494,0.518924,0.897186,0.008911,0.005919,0.002128,3.524540,0.126524,0.549184,0.173348,29.433826,2.574172,12.670606,3.567410,0.516808,0.010368,0.057375,0.020567,3.779412,0.342741,0.255292,0.119932,28.167074,7.398242,3.028493,3.614778,0.850744,0.034170,0.026939,0.020575,3.745352,0.447272,0.738121,0.306149,31.601514,8.604996,15.219667,8.060350,0.475249,0.048482,0.072077,0.032313,3.427438,25.565999,0.706997,3.762382,29.884294,0.662997
85,WNBP3MMB2L,1.2,0.3,rbf,1.088390,0.023259,3.329498,0.113398,0.158405,0.074337,21.723429,1.245226,1.227089,0.492468,0.897058,0.009013,0.005997,0.002036,3.522409,0.126481,0.550038,0.173635,29.468085,2.563341,12.685354,3.581545,0.516247,0.010604,0.057465,0.020898,3.778782,0.341391,0.254166,0.120120,28.120205,7.371185,3.005703,3.570104,0.851078,0.033796,0.026779,0.020328,3.745105,0.447883,0.737519,0.310357,31.635005,8.665673,15.223302,8.148882,0.474854,0.048738,0.071553,0.033549,3.425954,25.595757,0.706652,3.761943,29.877605,0.662966
66,G0KWFBOXF8,1.2,0.2,rbf,1.077978,0.103760,3.328479,0.113020,0.158584,0.073772,21.714307,1.243286,1.236126,0.507263,0.897096,0.009064,0.006044,0.002141,3.522185,0.126302,0.550313,0.174292,29.467933,2.562137,12.680178,3.587000,0.516216,0.010612,0.057368,0.021099,3.780203,0.340447,0.255558,0.118972,28.105903,7.342588,3.013551,3.571960,0.851097,0.033810,0.026857,0.020495,3.745072,0.449101,0.736159,0.310772,31.623558,8.662339,15.193613,8.133497,0.474817,0.048587,0.071092,0.033013,3.425332,25.591120,0.706656,3.762638,29.864731,0.662957
5,5O96EZF66F,1.2,0.7,rbf,1.010425,0.156295,3.330158,0.112733,0.158484,0.073479,21.653935,1.235735,1.198001,0.497414,0.897394,0.008885,0.005861,0.002073,3.527948,0.126272,0.547243,0.173188,29.397001,2.583680,12.647773,3.567262,0.517379,0.010722,0.057162,0.020566,3.783638,0.342606,0.255046,0.117076,28.208918,7.438875,3.015379,3.567980,0.850493,0.034288,0.027003,0.020900,3.750955,0.444610,0.735440,0.302330,31.591262,8.577403,15.175155,8.011851,0.474998,0.048294,0.071413,0.031375,3.429053,25.525468,0.707386,3.767297,29.900090,0.662746
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
117,G4244JDKZJ,1.2,0.7,sigmoid,1.647856,0.068830,36.897172,6.169636,1.183606,0.190532,6503.434252,1740.949516,99.077394,71.045665,-29.438302,7.574415,0.733391,0.259917,33.865662,5.893429,0.481982,0.173607,5451.202754,1516.949562,195.398387,23.554305,-105.937779,32.622387,37.421361,8.894282,36.722238,5.862882,1.017910,0.454670,6302.360392,1731.507320,150.132631,105.080594,-32.238034,7.554703,3.005367,3.728441,33.558273,5.430420,0.796670,0.189757,5313.173235,1640.125212,202.033494,73.458720,-107.507619,19.773806,38.072197,20.228038,35.381417,5977.318503,-67.688041,35.140256,5807.766813,-69.872826
7,TJGYGV6VK9,1.2,0.5,sigmoid,1.719398,0.051815,36.898575,6.166128,1.187427,0.188249,6503.205329,1736.985827,100.800198,74.833698,-29.436835,7.550771,0.741254,0.266041,33.868536,5.894598,0.483140,0.173390,5453.194634,1517.379374,193.997552,23.587786,-105.967794,32.631553,37.407315,8.897810,36.725202,5.853959,1.023961,0.458357,6301.287320,1725.307322,152.148031,103.025890,-32.235872,7.539452,3.024592,3.719870,33.564173,5.429613,0.795641,0.184926,5315.225507,1640.664527,200.995886,72.606382,-107.543032,19.781368,38.068579,20.213475,35.383556,5978.199981,-67.702315,35.144688,5808.256414,-69.889452
24,PB4N926EI6,1.2,0.3,sigmoid,1.689616,0.080936,36.901428,6.167190,1.188639,0.184770,6505.136974,1737.251828,99.784444,76.872432,-29.445328,7.550343,0.736306,0.261667,33.870171,5.894306,0.482968,0.173236,5454.957543,1518.784250,192.987985,23.965580,-105.997858,32.662047,37.399233,8.902723,36.730619,5.852072,1.029330,0.457651,6303.080495,1725.144452,151.439073,101.497779,-32.249081,7.546433,3.041358,3.724258,33.569249,5.429882,0.792042,0.180718,5317.663349,1643.730768,200.215270,72.820399,-107.580660,19.831795,38.064824,20.227323,35.385800,5980.047258,-67.721593,35.149934,5810.371922,-69.914870
87,EGWT0NONIY,1.2,0.2,sigmoid,1.666694,0.070777,36.900613,6.168856,1.189110,0.186129,6504.538358,1738.331187,100.164065,76.649198,-29.442288,7.556823,0.737812,0.255373,33.871493,5.894421,0.484536,0.171865,5455.424979,1518.731545,192.307512,24.253846,-106.004124,32.659605,37.389419,8.917975,36.731773,5.851921,1.029177,0.456824,6302.549517,1726.556842,151.374369,101.016030,-32.246270,7.550306,3.041819,3.723951,33.570653,5.430580,0.794216,0.178340,5318.559573,1643.562672,199.172600,73.016661,-107.586011,19.829238,38.025458,20.235225,35.386053,5979.981668,-67.723206,35.151213,5810.554545,-69.916140


In [7]:
best_params = dict(results_df.iloc[0][[col for col in results_df.columns if 'param_' in col]])
best_params = {key.replace('param_', ''): value for key, value in best_params.items()}
print('Best parameters: {}'.format(best_params))

Best parameters: {'C': 1.2, 'epsilon': 0.1, 'kernel': 'rbf'}


## Best model

In [8]:
# Load data
X_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_X_train_{}.npy'.format(JOINT, DATA_ID)))
X_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_X_test_{}.npy'.format(JOINT, DATA_ID)))
Y_train = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_Y_train_{}.npy'.format(JOINT, DATA_ID)))
Y_test = np.load(os.path.join(RESULTS_PATH, DATA_ID, 'data', '{}_Y_test_{}.npy'.format(JOINT, DATA_ID)))

In [9]:
# Setup the model
model = MultiOutputRegressor(SVR(**best_params, verbose=0), n_jobs=-1)

# Train the model
t_start = time.time()
model.fit(X_train, Y_train)
t_end = time.time()

print('Training time: {:.4f}'.format(t_end - t_start))

# Get the scores
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

results = {
    'Train': {
        'MAE': mean_absolute_error(Y_train, train_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_train, train_preds, multioutput='raw_values'),
        'R2': r2_score(Y_train, train_preds, multioutput='raw_values')
    },
    'Test': {
        'MAE': mean_absolute_error(Y_test, test_preds, multioutput='raw_values'),
        'MSE': mean_squared_error(Y_test, test_preds, multioutput='raw_values'),
        'R2': r2_score(Y_test, test_preds, multioutput='raw_values')
    }       

}

Training time: 2.3648


In [10]:
for subset in ['Train', 'Test']:
    for f, force in enumerate(['Fx', 'Fy']):
        for loss in ['MAE', 'MSE', 'R2']:
            scores = [results[subset][loss][i + f] for i in range(0, len(CELLS) * 2, 2)]
            print(' '.join([subset, force, loss]) + ': {:.4f} ± {:.4f}'.format(np.mean(scores), np.std(scores)))

Train Fx MAE: 3.3318 ± 0.1608
Train Fx MSE: 21.3503 ± 0.9115
Train Fx R2: 0.9000 ± 0.0047
Train Fy MAE: 3.5355 ± 0.5565
Train Fy MSE: 29.5699 ± 12.6971
Train Fy R2: 0.5147 ± 0.0562
Test Fx MAE: 4.4937 ± 0.0224
Test Fx MSE: 32.7052 ± 0.4460
Test Fx R2: 0.7988 ± 0.0228
Test Fy MAE: 3.3093 ± 0.2047
Test Fy MSE: 22.0289 ± 0.0768
Test Fy R2: 0.4748 ± 0.0397


In [None]:
# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:, 3], Y_train[:, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:, 3], train_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_train[:100, 3], Y_train[:100, 4], label='true', alpha=0.3)
# plt.scatter(train_preds[:100, 3], train_preds[:100, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()

# plt.figure(figsize=(20,15))
# plt.scatter(Y_test[:, 3], Y_test[:, 4], label='true', alpha=0.3)
# plt.scatter(test_preds[:, 3], test_preds[:, 4], label='preds', alpha=0.3)
# plt.legend()
# plt.show()