In [1]:
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split

In [2]:
# Candidate radiomics feature tables; FILE_INDEX selects the one used in this run.
csv_file_names = [
    "pyradiomics_extraction_box_with_correct_mask.csv",
    "pyradiomics_extraction_box_without_correct_mask.csv",
    "pyradiomics_extraction_segmentation_maskcorrect.csv",
    "pyradiomics_extraction_segmentation_no_maskcorrect.csv",
]

FILE_INDEX = 2
FILENAME = '../Data/Without Demographic Features/' + csv_file_names[FILE_INDEX]
CLASS_LABELS = '../Data/Patient class labels.csv'

# Seed forwarded as random_state to every train/val/test split in this notebook.
randgen = 12345678

# Split fractions. train_test_val_split only receives train_size and val_size; the
# test share is whatever remains. Check that test_size agrees with that remainder so
# editing one fraction without the others fails loudly instead of being ignored.
train_size = 0.6
test_size = 0.2
val_size = 0.2
assert np.isclose(train_size + val_size + test_size, 1.0), "train/val/test fractions must sum to 1"

In [3]:
# Read the radiomics feature table (dropping the optional 'sequence' column) and the
# patient class labels, inner-join them on the patient identifier, then drop both
# identifier columns so only features and label columns remain.
# NOTE(review): if FILENAME holds several rows per patient (e.g. one per sequence),
# the row-level splits further down could place one patient in more than one split;
# confirm patient IDs are unique here.
df = pd.read_csv(FILENAME).drop(columns='sequence', errors='ignore')
labels = pd.read_csv(CLASS_LABELS)
total_features = df.merge(
    labels,
    how='inner',
    left_on='patient',
    right_on='Patient ID',
).drop(columns=['Patient ID', 'patient'])
total_features

Unnamed: 0,original_shape_Elongation,original_shape_Flatness,original_shape_LeastAxisLength,original_shape_MajorAxisLength,original_shape_Maximum2DDiameterColumn,original_shape_Maximum2DDiameterRow,original_shape_Maximum2DDiameterSlice,original_shape_Maximum3DDiameter,original_shape_MeshVolume,original_shape_MinorAxisLength,...,original_glszm_ZoneVariance,original_ngtdm_Busyness,original_ngtdm_Coarseness,original_ngtdm_Complexity,original_ngtdm_Contrast,original_ngtdm_Strength,ER,PR,HER2,Mol Subtype
0,0.807005,0.729780,23.614309,32.358102,36.073737,38.431136,39.309346,45.798426,11352.720347,26.113156,...,4.009263e+06,0.0,1000000.0,0.0,0.0,0.0,0,0,1,2
1,0.641558,0.577887,17.378564,30.072592,23.148662,29.852205,29.748492,33.784677,4051.005400,19.293327,...,6.054147e+06,0.0,1000000.0,0.0,0.0,0.0,0,0,0,3
2,0.756223,0.272616,13.050754,47.872295,40.488107,43.126056,56.606869,57.665740,11866.396187,36.202150,...,1.663605e+07,0.0,1000000.0,0.0,0.0,0.0,1,1,0,0
3,0.718139,0.610317,8.730117,14.304230,11.420813,14.972284,14.477716,16.540987,829.333325,10.272429,...,0.000000e+00,0.0,1000000.0,0.0,0.0,0.0,1,1,0,0
4,0.687760,0.436019,22.539193,51.693090,55.175581,49.674566,36.443449,58.274413,25572.814941,35.552459,...,4.436710e+08,0.0,1000000.0,0.0,0.0,0.0,1,0,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
917,0.886914,0.668151,28.553043,42.734408,45.886501,41.977361,51.383808,56.016573,28567.316541,37.901723,...,0.000000e+00,0.0,1000000.0,0.0,0.0,0.0,1,1,0,0
918,0.757419,0.596730,12.864613,21.558522,21.066007,20.543932,25.375225,27.268354,2558.161146,16.328835,...,4.485585e+06,0.0,1000000.0,0.0,0.0,0.0,1,1,0,0
919,0.872632,0.837806,15.199560,18.142097,19.038279,20.649190,20.505122,24.348411,2671.696974,15.831368,...,0.000000e+00,0.0,1000000.0,0.0,0.0,0.0,1,1,0,0
920,0.855874,0.735409,30.453436,41.410193,43.076382,45.448460,48.143552,54.398662,28568.303573,35.441921,...,1.504081e+08,0.0,1000000.0,0.0,0.0,0.0,1,1,0,0


In [43]:
def train_test_val_split(df, train_ratio, val_ratio, stratify_class_name = 'ER', random_state = 2454259):
    """Split ``df`` into stratified train / validation / test partitions.

    Two successive ``train_test_split`` calls are made. The first carves off
    ``train_ratio`` of the rows as the training set. The second divides the
    remaining (held-out) rows between validation and test, so that validation
    receives ``val_ratio`` of the *original* rows. The test set gets the rest,
    i.e. roughly ``1 - train_ratio - val_ratio`` of the rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to split; must contain the column ``stratify_class_name``.
    train_ratio : float
        Fraction of all rows assigned to the training set, in (0, 1).
    val_ratio : float
        Fraction of all rows assigned to the validation set. Must be > 0, and
        ``train_ratio + val_ratio`` must be < 1 so the test set is non-empty.
    stratify_class_name : str
        Column passed as ``stratify`` to both splits.
    random_state : int
        Seed forwarded to both ``train_test_split`` calls.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(train_df, val_df, test_df)``, in that order.

    Raises
    ------
    ValueError
        If the ratios do not describe three non-empty fractions.
    KeyError
        If ``stratify_class_name`` is not a column of ``df``.
    """
    if not 0 < train_ratio < 1:
        raise ValueError("train_ratio must be in (0, 1), got {}".format(train_ratio))
    if val_ratio <= 0 or train_ratio + val_ratio >= 1:
        raise ValueError(
            "val_ratio must be > 0 and train_ratio + val_ratio must be < 1 "
            "(got train_ratio={}, val_ratio={})".format(train_ratio, val_ratio)
        )
    if stratify_class_name not in df.columns:
        raise KeyError("stratification column {!r} not found in df".format(stratify_class_name))

    # Share of the held-out rows (everything not in train) that goes to validation.
    val_ratio_adj = val_ratio / (1 - train_ratio)

    train_df, holdout_df = train_test_split(
        df, train_size=train_ratio, stratify=df[stratify_class_name], random_state=random_state)
    # Re-stratify on the held-out rows so validation and test are stratified as well.
    val_df, test_df = train_test_split(
        holdout_df, train_size=val_ratio_adj, stratify=holdout_df[stratify_class_name],
        random_state=random_state)

    return train_df, val_df, test_df

In [44]:
# ER target: drop the other three label columns, split stratified on ER,
# and expose ER under the generic column name `label` in every partition.
er_train, er_val, er_test = (
    split.rename(columns={'ER': 'label'})
    for split in train_test_val_split(
        total_features.drop(columns=['PR', 'HER2', 'Mol Subtype']),
        train_size,
        val_size,
        'ER',
        random_state=randgen,
    )
)

print(er_train.shape, er_val.shape, er_test.shape)
er_train

(553, 108) (184, 108) (185, 108)


Unnamed: 0,original_shape_Elongation,original_shape_Flatness,original_shape_LeastAxisLength,original_shape_MajorAxisLength,original_shape_Maximum2DDiameterColumn,original_shape_Maximum2DDiameterRow,original_shape_Maximum2DDiameterSlice,original_shape_Maximum3DDiameter,original_shape_MeshVolume,original_shape_MinorAxisLength,...,original_glszm_SmallAreaLowGrayLevelEmphasis,original_glszm_ZoneEntropy,original_glszm_ZonePercentage,original_glszm_ZoneVariance,original_ngtdm_Busyness,original_ngtdm_Coarseness,original_ngtdm_Complexity,original_ngtdm_Contrast,original_ngtdm_Strength,label
614,0.943156,0.788066,23.312858,29.582374,32.939410,34.807445,32.411310,40.699790,10580.530688,27.900783,...,6.026235e-01,1.879965e+00,0.000458,3.803180e+07,0.0,1000000.0,0.0,0.0,0.0,1
431,0.940443,0.820170,15.900204,19.386472,21.389522,21.925923,23.414392,27.244395,3390.894471,18.231875,...,6.666667e-01,9.182958e-01,0.000435,1.053711e+07,0.0,1000000.0,0.0,0.0,0.0,1
550,0.757800,0.686727,14.948235,21.767371,19.069064,23.642477,23.190269,26.743910,3461.648541,16.495304,...,4.458375e-08,-3.203427e-16,0.000211,0.000000e+00,0.0,1000000.0,0.0,0.0,0.0,1
314,0.821802,0.361187,33.726028,93.375432,83.740345,109.558095,85.319249,115.062760,105376.516745,76.736147,...,4.832436e-01,2.800958e+00,0.001069,1.607501e+08,0.0,1000000.0,0.0,0.0,0.0,0
907,0.886801,0.808821,16.927384,20.928472,22.092000,23.768697,23.705430,28.231320,4170.036681,18.559388,...,5.000000e-01,1.000000e+00,0.000164,3.714293e+07,0.0,1000000.0,0.0,0.0,0.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
341,0.828628,0.411786,12.616856,30.639334,24.174221,28.382503,34.016599,35.456015,6375.656869,25.388604,...,8.182434e-09,-3.203427e-16,0.000090,0.000000e+00,0.0,1000000.0,0.0,0.0,0.0,1
142,0.804420,0.624621,9.148921,14.647156,13.197095,14.619172,16.293664,17.813576,736.698720,11.782469,...,1.262467e-06,-3.203427e-16,0.001124,0.000000e+00,0.0,1000000.0,0.0,0.0,0.0,1
15,0.863853,0.804463,52.482372,65.239043,76.690792,84.567373,85.369789,99.191729,87076.836228,56.356941,...,4.931760e-01,2.791384e+00,0.004592,4.288448e+07,0.0,1000000.0,0.0,0.0,0.0,1
882,0.925940,0.694403,24.011136,34.578106,36.393266,34.398698,40.974263,45.507474,17037.308159,32.017258,...,1.757428e-09,-3.203427e-16,0.000042,0.000000e+00,0.0,1000000.0,0.0,0.0,0.0,1


In [45]:
# PR target: drop the ER, HER2 and molecular-subtype columns, split stratified on PR,
# and expose PR under the generic column name `label` in every partition.
pr_train, pr_val, pr_test = (
    split.rename(columns={'PR': 'label'})
    for split in train_test_val_split(
        total_features.drop(columns=['ER', 'HER2', 'Mol Subtype']),
        train_size,
        val_size,
        'PR',
        random_state=randgen,
    )
)

In [47]:
# HER2 target: drop the ER, PR and molecular-subtype columns, split stratified on HER2,
# and expose HER2 under the generic column name `label` in every partition.
her2_train, her2_val, her2_test = (
    split.rename(columns={'HER2': 'label'})
    for split in train_test_val_split(
        total_features.drop(columns=['ER', 'PR', 'Mol Subtype']),
        train_size,
        val_size,
        'HER2',
        random_state=randgen,
    )
)

In [49]:
# Molecular-subtype target: drop the ER, PR and HER2 columns, split stratified on
# 'Mol Subtype', and expose it under the generic column name `label` in every partition.
molsub_train, molsub_val, molsub_test = (
    split.rename(columns={'Mol Subtype': 'label'})
    for split in train_test_val_split(
        total_features.drop(columns=['ER', 'PR', 'HER2']),
        train_size,
        val_size,
        'Mol Subtype',
        random_state=randgen,
    )
)

In [52]:
# Save every split of every target as CSV: New Split/<target>/{train,test,val}.csv
path = "New Split/"
target_splits = {
    "ER": {"train": er_train, "test": er_test, "val": er_val},
    "PR": {"train": pr_train, "test": pr_test, "val": pr_val},
    "HER2": {"train": her2_train, "test": her2_test, "val": her2_val},
    "Mol_Subtype": {"train": molsub_train, "test": molsub_test, "val": molsub_val},
}

for target_dir, splits in target_splits.items():
    # Leave `mode` at its default (0o777, reduced by the umask). A bare literal 777
    # would be decimal (== 0o1411). On POSIX that creates the directory without
    # write/search permission for its owner, so the CSV writes below could fail
    # for non-root users.
    os.makedirs(path + target_dir, exist_ok=True)
    for split_name, split_df in splits.items():
        split_df.to_csv(path + target_dir + "/" + split_name + ".csv", index=False)