In [28]:
from pathlib import Path
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from knee_stress_predict.config import raw_data_dir, processed_data_dir

## Step 1: Load data


In [29]:
data_set_name = "set_3"
train_file_path = Path.joinpath(processed_data_dir, data_set_name,  "train_cleaned.csv")
test_file_path = Path.joinpath(processed_data_dir, data_set_name,  "test_cleaned.csv")
train_data = pd.read_csv(train_file_path).drop('Unnamed: 0', axis=1)
test_data = pd.read_csv(test_file_path).drop('Unnamed: 0', axis=1)

In [30]:
train_data

Unnamed: 0,Code,Patella_PN,Femur_PN,Tibia_PN,Patella_Car_PN,Femur_Car_PN,Tibia_M_Car_PN,Tibia_L_Car_PN,Patella_volume,Femur_volume,...,Simulation_len,Max_dist_femur_tibia_lat_car,Min_dist_femur_tibia_lat_car,Mean_dist_femur_tibia_lat_car,Max_dist_femur_tibia_med_car,Min_dist_femur_tibia_med_car,Mean_dist_femur_tibia_med_car,frame,Max_tib_med_contact_pressure,Max_tib_lat_contact_pressure
0,9003406M00,1137,4142,2789,5275,21420,2840,2840,33682.157434,231437.991665,...,0,13.941570,0.987391,6.105886,10.304892,1.376902,4.580906,0,7.619495,6.696390
1,9003406M00,1137,4142,2789,5275,21420,2840,2840,33682.157434,231437.991665,...,0,13.941570,0.987391,6.105886,10.304892,1.376902,4.580906,1,8.068417,5.042103
2,9003406M00,1137,4142,2789,5275,21420,2840,2840,33682.157434,231437.991665,...,0,13.941570,0.987391,6.105886,10.304892,1.376902,4.580906,2,7.796326,5.344840
3,9003406M00,1137,4142,2789,5275,21420,2840,2840,33682.157434,231437.991665,...,0,13.941570,0.987391,6.105886,10.304892,1.376902,4.580906,3,8.058777,6.240049
4,9003406M00,1137,4142,2789,5275,21420,2840,2840,33682.157434,231437.991665,...,0,13.941570,0.987391,6.105886,10.304892,1.376902,4.580906,4,8.098284,7.306662
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30274,9993846M12,873,3921,2831,5690,21495,2840,2840,24799.135262,226161.340527,...,1,18.862473,2.616759,6.815766,15.943200,2.833736,6.482209,236,11.455851,7.974414
30275,9993846M12,873,3921,2831,5690,21495,2840,2840,24799.135262,226161.340527,...,1,18.862473,2.616759,6.815766,15.943200,2.833736,6.482209,237,11.642740,8.043670
30276,9993846M12,873,3921,2831,5690,21495,2840,2840,24799.135262,226161.340527,...,1,18.862473,2.616759,6.815766,15.943200,2.833736,6.482209,238,12.039758,8.359842
30277,9993846M12,873,3921,2831,5690,21495,2840,2840,24799.135262,226161.340527,...,1,18.862473,2.616759,6.815766,15.943200,2.833736,6.482209,239,11.383085,8.170403


In [31]:
def get_tframe_dataset(tframe_num, dataset):
    tframe_dataset = dataset[dataset['frame'] == tframe_num]
    return tframe_dataset

In [32]:
tframe_nums = range(0, 240, 25)

#Create a dictionary that contains 10 timeframes
datasets = {}
for tframe_num in tframe_nums:
    datasets[tframe_num] = get_tframe_dataset(tframe_num, train_data)


In [34]:
datasets[0]

Unnamed: 0,Code,Patella_PN,Femur_PN,Tibia_PN,Patella_Car_PN,Femur_Car_PN,Tibia_M_Car_PN,Tibia_L_Car_PN,Patella_volume,Femur_volume,...,Simulation_len,Max_dist_femur_tibia_lat_car,Min_dist_femur_tibia_lat_car,Mean_dist_femur_tibia_lat_car,Max_dist_femur_tibia_med_car,Min_dist_femur_tibia_med_car,Mean_dist_femur_tibia_med_car,frame,Max_tib_med_contact_pressure,Max_tib_lat_contact_pressure
0,9003406M00,1137,4142,2789,5275,21420,2840,2840,33682.157434,231437.991665,...,0,13.941570,0.987391,6.105886,10.304892,1.376902,4.580906,0,7.619495,6.696390
241,9003406M12,1126,4113,2684,5900,21030,2840,2840,33039.218053,228671.171460,...,1,13.560157,0.800195,6.226970,10.677269,1.029139,4.306716,0,6.554834,7.020112
482,9007827M00,771,3414,2596,4485,18385,2840,2840,21736.919532,184294.924991,...,0,12.892850,1.504442,4.844820,9.243396,1.418712,4.015682,0,4.956423,6.670447
723,9007827M12,730,3404,2600,5055,18755,2840,2840,20866.993760,184114.045604,...,1,11.617199,1.119438,4.698536,9.107268,1.398252,3.890993,0,5.840032,5.090442
964,9040390M00,670,3907,2560,4935,19150,2840,2840,17988.202111,226082.515840,...,0,14.107578,2.342441,6.161425,10.770141,2.071546,5.517447,0,7.856375,6.560999
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29074,9993650M00,633,3084,2205,4255,20260,2185,2840,15168.028148,146698.667898,...,0,14.732158,1.686715,5.659945,13.039994,1.308433,4.848711,0,11.855247,10.523536
29315,9993833M00,795,3669,2724,4765,20410,2840,2840,22375.020348,207318.185631,...,0,16.633343,1.874496,6.641354,15.426130,2.258161,4.898385,0,7.753703,7.483450
29556,9993833M12,801,3805,2575,4885,20400,2840,2840,22439.863013,215968.865103,...,1,17.640301,1.481010,6.875417,12.561378,2.246910,4.471387,0,6.899568,4.661014
29797,9993846M00,868,3794,3029,5430,20755,2840,2840,25084.417057,216622.182978,...,0,19.743281,2.204936,6.574413,16.267729,1.759706,5.525146,0,7.215711,6.326750
