In [12]:
from pathlib import Path
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from knee_stress_predict.config import raw_data_dir, processed_data_dir
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor

## Load data

In [13]:
data_set_name = "set_3"

# The cleaned train/test CSVs are read from <processed_data_dir>/<data_set_name>/.
train_file_path = processed_data_dir / data_set_name / "train_cleaned.csv"
test_file_path = processed_data_dir / data_set_name / "test_cleaned.csv"

# 'Unnamed: 0' is discarded right after loading -- presumably an index column
# that was written into the CSV when it was saved; verify against the export step.
train_data = pd.read_csv(train_file_path).drop(columns='Unnamed: 0')
test_data = pd.read_csv(test_file_path).drop(columns='Unnamed: 0')

In [14]:
# Print a summary of the training frame: column names, non-null counts and dtypes.
train_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30279 entries, 0 to 30278
Data columns (total 55 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   Code                           30279 non-null  object 
 1   Patella_PN                     30279 non-null  int64  
 2   Femur_PN                       30279 non-null  int64  
 3   Tibia_PN                       30279 non-null  int64  
 4   Patella_Car_PN                 30279 non-null  int64  
 5   Femur_Car_PN                   30279 non-null  int64  
 6   Tibia_M_Car_PN                 30279 non-null  int64  
 7   Tibia_L_Car_PN                 30279 non-null  int64  
 8   Patella_volume                 30279 non-null  float64
 9   Femur_volume                   30279 non-null  float64
 10  Tibia_volume                   30279 non-null  float64
 11  Patella_Car_volume             30279 non-null  float64
 12  Femur_Car_volume               30279 non-null 

In [15]:
def get_tframe_dataset(tframe_num, dataset):
    """Select the rows of ``dataset`` that belong to a single time frame.

    Parameters
    ----------
    tframe_num
        Value that the ``frame`` column must equal for a row to be kept.
    dataset
        Table with a ``frame`` column; it is not modified.

    Returns
    -------
    The subset of ``dataset`` whose ``frame`` equals ``tframe_num``, with the
    original row labels preserved (the index is not reset).
    """
    is_requested_frame = dataset['frame'] == tframe_num
    return dataset[is_requested_frame]

In [16]:
# Keep only the rows whose `frame` value is 0 from each split.
train_datasets = get_tframe_dataset(tframe_num=0, dataset=train_data)
test_datasets = get_tframe_dataset(tframe_num=0, dataset=test_data)

## Drop unnecessary columns

In [18]:
# The same columns are removed from both splits, so the list is written once.
# NOTE: this cell overwrites its own inputs; running it a second time raises
# KeyError because the columns are already gone.
columns_to_drop = ["frame", "femur_left_gap_p", "femur_right_gap_p", "Code"]
train_datasets = train_datasets.drop(columns=columns_to_drop)
test_datasets = test_datasets.drop(columns=columns_to_drop)

In [19]:
# Display the frame-0 training rows that remain after the column drop.
train_datasets

Unnamed: 0,Patella_PN,Femur_PN,Tibia_PN,Patella_Car_PN,Femur_Car_PN,Tibia_M_Car_PN,Tibia_L_Car_PN,Patella_volume,Femur_volume,Tibia_volume,...,Max_dist_femur_tibia_lat_car,Min_dist_femur_tibia_lat_car,Mean_dist_femur_tibia_lat_car,Max_dist_femur_tibia_med_car,Min_dist_femur_tibia_med_car,Mean_dist_femur_tibia_med_car,lat_Max_all_frames,med_Max_all_frames,Max_tib_med_contact_pressure,Max_tib_lat_contact_pressure
0,1137,4142,2789,5275,21420,2840,2840,33682.157434,231437.991665,139880.512171,...,13.941570,0.987391,6.105886,10.304892,1.376902,4.580906,15.890460,34.044975,7.619495,6.696390
241,1126,4113,2684,5900,21030,2840,2840,33039.218053,228671.171460,134963.134980,...,13.560157,0.800195,6.226970,10.677269,1.029139,4.306716,20.057634,31.649405,6.554834,7.020112
482,771,3414,2596,4485,18385,2840,2840,21736.919532,184294.924991,127637.611277,...,12.892850,1.504442,4.844820,9.243396,1.418712,4.015682,18.555532,26.223894,4.956423,6.670447
723,730,3404,2600,5055,18755,2840,2840,20866.993760,184114.045604,128295.783408,...,11.617199,1.119438,4.698536,9.107268,1.398252,3.890993,19.447876,24.808067,5.840032,5.090442
964,670,3907,2560,4935,19150,2840,2840,17988.202111,226082.515840,132402.041764,...,14.107578,2.342441,6.161425,10.770141,2.071546,5.517447,228.103546,137.528534,7.856375,6.560999
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29074,633,3084,2205,4255,20260,2185,2840,15168.028148,146698.667898,102326.881891,...,14.732158,1.686715,5.659945,13.039994,1.308433,4.848711,20.220541,36.707184,11.855247,10.523536
29315,795,3669,2724,4765,20410,2840,2840,22375.020348,207318.185631,146786.471170,...,16.633343,1.874496,6.641354,15.426130,2.258161,4.898385,19.478767,22.278835,7.753703,7.483450
29556,801,3805,2575,4885,20400,2840,2840,22439.863013,215968.865103,136753.285023,...,17.640301,1.481010,6.875417,12.561378,2.246910,4.471387,20.616781,26.929873,6.899568,4.661014
29797,868,3794,3029,5430,20755,2840,2840,25084.417057,216622.182978,168327.403757,...,19.743281,2.204936,6.574413,16.267729,1.759706,5.525146,17.599583,22.764488,7.215711,6.326750


In [20]:
# Display the frame-0 test rows that remain after the column drop.
test_datasets

Unnamed: 0,Patella_PN,Femur_PN,Tibia_PN,Patella_Car_PN,Femur_Car_PN,Tibia_M_Car_PN,Tibia_L_Car_PN,Patella_volume,Femur_volume,Tibia_volume,...,Max_dist_femur_tibia_lat_car,Min_dist_femur_tibia_lat_car,Mean_dist_femur_tibia_lat_car,Max_dist_femur_tibia_med_car,Min_dist_femur_tibia_med_car,Mean_dist_femur_tibia_med_car,lat_Max_all_frames,med_Max_all_frames,Max_tib_med_contact_pressure,Max_tib_lat_contact_pressure
0,788,4137,2689,5575,21035,2840,2840,22692.085291,248976.459037,146188.096235,...,13.779237,1.205597,4.27283,20.666032,2.178122,6.670601,24.228903,16.497473,5.602037,8.015358
241,527,3022,2229,3815,20220,2840,2840,10857.20256,146243.03644,102239.159975,...,21.284904,1.277765,6.60377,10.369182,1.176817,3.926511,20.081484,31.558796,14.084329,8.341026
482,971,3746,2799,4985,22785,2840,2840,30455.80484,217672.765536,155360.366258,...,16.570434,1.646447,5.989795,11.721987,1.394432,4.899639,26.539942,14.72912,7.665385,8.679437
723,867,3575,2604,4685,18085,2840,2840,26574.049046,207733.326164,139532.495131,...,19.327724,1.414403,7.044505,15.558528,2.203551,6.545311,30.914001,32.08202,7.569712,8.833647
964,796,3693,2820,4515,18505,2840,5160,21902.730915,206979.152648,150993.716861,...,14.922067,1.157918,5.545328,9.611753,1.956866,4.509022,31.127924,18.559311,7.339238,8.477967
1078,638,3334,2466,4685,20160,2315,2840,16611.521125,177709.783397,124081.197778,...,14.62971,1.216608,6.158827,9.355632,1.965896,4.348648,18.824017,28.095928,9.901465,9.094282
1319,725,3688,2568,5965,20005,2840,2840,19946.762534,204130.287956,137407.126695,...,12.480014,2.688065,6.360378,9.319604,1.602352,5.082682,22.74605,36.272587,7.42036,7.687375
1560,682,3536,2522,4165,19195,2355,2840,18909.762405,204468.276728,137445.470948,...,11.23042,1.152456,4.57115,11.410034,0.823154,3.888859,20.282326,20.356743,7.172037,11.942789
1801,653,3451,2169,4405,20275,2840,2840,15296.663201,173967.102959,99355.880702,...,14.342613,2.279677,5.25586,10.066185,0.754882,3.6424,22.52924,32.366993,11.36982,8.53199
2042,613,3064,2074,4095,18825,2840,2840,15630.041565,156580.112121,101463.017837,...,12.991571,1.120302,4.794382,12.351678,1.012572,4.638105,18.645138,33.404396,9.1162,8.827696
