In [91]:
from pathlib import Path
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from knee_stress_predict.config import raw_data_dir, processed_data_dir
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor

## Load data


In [103]:
data_set_name = "set_3"
# Cleaned train/test splits for the chosen data set
train_file_path = processed_data_dir / data_set_name / "train_cleaned.csv"
test_file_path = processed_data_dir / data_set_name / "test_cleaned.csv"
# The CSVs carry an 'Unnamed: 0' column (presumably the index written at save
# time); discard it and keep the fresh RangeIndex.
train_data = pd.read_csv(train_file_path).drop(columns='Unnamed: 0')
test_data = pd.read_csv(test_file_path).drop(columns='Unnamed: 0')

In [104]:
# Column names, dtypes and non-null counts of the cleaned training table
train_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30279 entries, 0 to 30278
Data columns (total 46 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   Code                           30279 non-null  object 
 1   Patella_PN                     30279 non-null  int64  
 2   Femur_PN                       30279 non-null  int64  
 3   Tibia_PN                       30279 non-null  int64  
 4   Patella_Car_PN                 30279 non-null  int64  
 5   Femur_Car_PN                   30279 non-null  int64  
 6   Tibia_M_Car_PN                 30279 non-null  int64  
 7   Tibia_L_Car_PN                 30279 non-null  int64  
 8   Patella_volume                 30279 non-null  float64
 9   Femur_volume                   30279 non-null  float64
 10  Tibia_volume                   30279 non-null  float64
 11  Patella_Car_volume             30279 non-null  float64
 12  Femur_Car_volume               30279 non-null 

In [105]:
def get_tframe_dataset(tframe_num, dataset, frame_col='frame'):
    """Return the rows of ``dataset`` that belong to a single timeframe.

    Parameters
    ----------
    tframe_num : scalar
        Timeframe value to select; compared with ``==`` against ``frame_col``.
    dataset : pandas.DataFrame
        Table containing a ``frame_col`` column.
    frame_col : str, default 'frame'
        Name of the column that holds the timeframe number.

    Returns
    -------
    pandas.DataFrame
        Independent copy of the matching rows (original index preserved);
        empty if no row matches.
    """
    mask = dataset[frame_col] == tframe_num
    # Explicit copy: later in-place edits on a per-frame table must not raise
    # SettingWithCopyWarning or be ambiguous about touching ``dataset``.
    return dataset.loc[mask].copy()

In [106]:
# 10 timeframes: 0, 25, 50, ..., 225
tframe_nums = range(0, 240, 25)

# Per-timeframe subsets of each split, keyed by the frame number
train_datasets, test_datasets = {}, {}
for tframe_num in tframe_nums:
    train_datasets[tframe_num] = get_tframe_dataset(tframe_num, train_data)
    test_datasets[tframe_num] = get_tframe_dataset(tframe_num, test_data)


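## Simple benchmark MSE

Baseline: for each timeframe, predict the training-set mean of each target (medial and lateral max tibial contact pressure) for every test sample and report the test MSE.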

In [107]:
# Mean-prediction baseline: for every timeframe and each target, predict the
# training-set mean for all test rows and record the resulting test MSE.
# Rows are collected in a list and converted once at the end: DataFrame.append
# was removed in pandas 2.0, and growing a frame row by row is quadratic.
benchmark_targets = {'med': 'Max_tib_med_contact_pressure', 'lat': 'Max_tib_lat_contact_pressure'}

benchmark_rows = []
for frame, train_frame in train_datasets.items():
    test_frame = test_datasets[frame]
    row = {'frame': frame}
    for side, target in benchmark_targets.items():
        predicted = train_frame[target].mean()
        row[f'{side}_benchmark_MSE'] = ((test_frame[target] - predicted) ** 2).mean()
    benchmark_rows.append(row)

mse_simple = pd.DataFrame(benchmark_rows, columns=['frame', 'med_benchmark_MSE', 'lat_benchmark_MSE'])

In [108]:
# Per-timeframe test MSE of the mean-prediction baseline
mse_simple

Unnamed: 0,frame,med_benchmark_MSE,lat_benchmark_MSE
0,0.0,3.636854,4.885563
1,25.0,29.75381,8.142631
2,50.0,222.903957,21.289709
3,75.0,60.357285,18.445215
4,100.0,92.651709,23.605557
5,125.0,77.510633,20.768768
6,150.0,82.510877,18.547202
7,175.0,111.937007,12.345029
8,200.0,81.277607,9.835287
9,225.0,12.943481,3.406698


## Random Forest

In [109]:
# Per-timeframe random forest: one regressor per target (medial / lateral max
# tibial contact pressure), trained on that frame's training rows and scored on
# the matching test rows. Rows are collected in a list and converted once, since
# DataFrame.append was removed in pandas 2.0.
rf_targets = {'med': 'Max_tib_med_contact_pressure', 'lat': 'Max_tib_lat_contact_pressure'}
rf_random_state = 42  # fixed seed so the MSE table below is reproducible
# Targets plus the identifier / timeframe columns are excluded from the features.
rf_non_features = set(rf_targets.values()) | {'Code', 'frame'}

rf_rows = []
for frame, train_frame in train_datasets.items():
    test_frame = test_datasets[frame]
    # Keep the DataFrame's column order: a set difference yields a hash-dependent
    # order that changes between sessions and perturbs the fitted forest.
    feat_cols = [col for col in train_frame.columns if col not in rf_non_features]
    train_x = train_frame[feat_cols]
    test_x = test_frame[feat_cols]

    row = {'frame': frame}
    for side, target in rf_targets.items():
        model = RandomForestRegressor(random_state=rf_random_state)
        # 1-D Series target: a one-column DataFrame triggers sklearn's
        # DataConversionWarning (hidden here by the global warnings filter).
        model.fit(train_x, train_frame[target])
        # sklearn signature is mean_squared_error(y_true, y_pred).
        row[f'{side}_rf_MSE'] = mean_squared_error(test_frame[target], model.predict(test_x))
    rf_rows.append(row)

mse_rf = pd.DataFrame(rf_rows, columns=['frame', 'med_rf_MSE', 'lat_rf_MSE'])

In [110]:
# Per-timeframe test MSE of the random-forest models
mse_rf

Unnamed: 0,frame,med_rf_MSE,lat_rf_MSE
0,0.0,3.811144,4.227729
1,25.0,48.612159,8.539264
2,50.0,426.042633,18.620912
3,75.0,242.205623,102.787584
4,100.0,168.137798,26.287569
5,125.0,165.212561,30.81843
6,150.0,485.741782,24.02866
7,175.0,344.258752,14.891539
8,200.0,203.810161,35.680066
9,225.0,445.700542,4.088089


In [111]:
# Side-by-side table of baseline vs. random-forest test MSE per timeframe.
# validate='one_to_one' raises if either table has duplicate frames, which an
# inner merge would otherwise silently multiply into extra rows.
mse = mse_simple.merge(mse_rf, on='frame', validate='one_to_one')

In [112]:
# Baseline vs. random-forest test MSE, one row per timeframe
mse

Unnamed: 0,frame,med_benchmark_MSE,lat_benchmark_MSE,med_rf_MSE,lat_rf_MSE
0,0.0,3.636854,4.885563,3.811144,4.227729
1,25.0,29.75381,8.142631,48.612159,8.539264
2,50.0,222.903957,21.289709,426.042633,18.620912
3,75.0,60.357285,18.445215,242.205623,102.787584
4,100.0,92.651709,23.605557,168.137798,26.287569
5,125.0,77.510633,20.768768,165.212561,30.81843
6,150.0,82.510877,18.547202,485.741782,24.02866
7,175.0,111.937007,12.345029,344.258752,14.891539
8,200.0,81.277607,9.835287,203.810161,35.680066
9,225.0,12.943481,3.406698,445.700542,4.088089
