In [29]:
import random
import pandas as pd
import numpy as np
import seaborn as sns
import networkx as nx
import matplotlib.pyplot as plt
%matplotlib inline

In [30]:
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, MinMaxScaler
from fancyimpute import KNN, IterativeImputer, SimpleFill, SoftImpute, IterativeSVD, MatrixFactorization, NuclearNormMinimization, BiScaler

In [31]:
# Seed both Python's and NumPy's global random generators with the same value.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)

## 1. Dataset
- training : TrainingWiDS2021.csv
- test : UnlabeledWiDS2021.csv
- descriptions : DataDictionaryWiDS2021.csv

In [32]:
# Colab-only: mount Google Drive so the dataset paths under /content/drive resolve.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [33]:
# Read the labelled training CSV and discard its 'Unnamed: 0' column.
df_tr = pd.read_csv(
    '/content/drive/MyDrive/dataset/WiDS2021/TrainingWiDS2021.csv'
).drop(columns=['Unnamed: 0'])
df_tr.shape

(130157, 180)

In [34]:
# Read the unlabelled test CSV and discard its 'Unnamed: 0' column.
df_te = pd.read_csv(
    '/content/drive/MyDrive/dataset/WiDS2021/UnlabeledWiDS2021.csv'
).drop(columns=['Unnamed: 0'])
df_te.shape

(10234, 179)

### 1.1. Drop Columns with high missing ratio

- check missing ratio for each column
- compare missing ratios between training set and test set
- Decide what to drop

In [35]:
def check_missing_data(df):
    """Summarise the missing values of every column in ``df``.

    Prints two views of the summary -- the columns that have at least one
    missing value, then the columns whose missing ratio is above 0.5 -- and
    returns the complete summary.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to inspect.

    Returns
    -------
    pandas.DataFrame
        One row per column of ``df`` with the fields ``column_name``,
        ``num_miss_rows`` and ``miss_ratio`` (missing rows divided by the
        number of rows), ordered by descending missing count.
    """
    n_rows = df.shape[0]

    # Per-column NaN counts, largest first, flattened into a two-column frame.
    missing_df = (
        df.isna()
        .sum()
        .sort_values(ascending=False)
        .rename_axis('column_name')
        .reset_index(name='num_miss_rows')
    )
    missing_df['miss_ratio'] = missing_df['num_miss_rows'] / n_rows

    has_missing = missing_df['num_miss_rows'] > 0
    mostly_missing = missing_df['miss_ratio'] > 0.5

    print(missing_df[has_missing])

    print(missing_df[mostly_missing])
    return missing_df

In [36]:
# Missing-value summary of the training frame (the helper also prints it).
tr_missing = check_missing_data(df_tr)

          column_name  num_miss_rows  miss_ratio
0    h1_bilirubin_min         119861    0.920896
1    h1_bilirubin_max         119861    0.920896
2      h1_albumin_min         119005    0.914319
3      h1_albumin_max         119005    0.914319
4      h1_lactate_max         118467    0.910185
..                ...            ...         ...
155      d1_sysbp_max            271    0.002082
156  d1_heartrate_max            262    0.002013
157  d1_heartrate_min            262    0.002013
158  icu_admit_source            240    0.001844
159            gender             66    0.000507

[160 rows x 3 columns]
         column_name  num_miss_rows  miss_ratio
0   h1_bilirubin_min         119861    0.920896
1   h1_bilirubin_max         119861    0.920896
2     h1_albumin_min         119005    0.914319
3     h1_albumin_max         119005    0.914319
4     h1_lactate_max         118467    0.910185
..               ...            ...         ...
68  d1_bilirubin_max          76735    0.589557
69  

In [37]:
# Missing-value summary of the test frame (the helper also prints it).
te_missing = check_missing_data(df_te)

          column_name  num_miss_rows  miss_ratio
0      h1_lactate_max           9421    0.920559
1      h1_lactate_min           9421    0.920559
2    h1_bilirubin_max           9407    0.919191
3    h1_bilirubin_min           9407    0.919191
4      h1_albumin_min           9365    0.915087
..                ...            ...         ...
154     d1_diasbp_max             23    0.002247
155      d1_sysbp_min             23    0.002247
156      d1_sysbp_max             23    0.002247
157     d1_diasbp_min             23    0.002247
158            gender              5    0.000489

[159 rows x 3 columns]
           column_name  num_miss_rows  miss_ratio
0       h1_lactate_max           9421    0.920559
1       h1_lactate_min           9421    0.920559
2     h1_bilirubin_max           9407    0.919191
3     h1_bilirubin_min           9407    0.919191
4       h1_albumin_min           9365    0.915087
..                 ...            ...         ...
69    d1_bilirubin_min           5860 

In [38]:
# train and test set with same missing ratios?
# Columns above the 50% threshold in train but not in test ...
set(tr_missing.loc[tr_missing.miss_ratio > .5].column_name).\
difference(set(te_missing.loc[te_missing.miss_ratio > .5].column_name))

# ... and columns above the threshold in test but not in train.
set(te_missing.loc[te_missing.miss_ratio > .5].column_name).\
difference(set(tr_missing.loc[tr_missing.miss_ratio > .5].column_name))

# Look up the same column in both summaries.
tr_missing.loc[tr_missing.column_name=='urineoutput_apache']
# Each summary must be filtered with its own mask: the two frames are sorted
# independently, so a mask built from tr_missing selects te_missing rows by
# row label rather than by column name.
te_missing.loc[te_missing.column_name=='urineoutput_apache']

Unnamed: 0,column_name,num_miss_rows,miss_ratio
73,urineoutput_apache,5190,0.507133


In [39]:
# since missing ratios are very similar => drop all columns with missing ratio above 50%
# The list of columns comes from the test-set summary and is applied to both frames.
drop_columns = te_missing.loc[te_missing.miss_ratio > .5].column_name.values
df_tr = df_tr.drop(columns = drop_columns, inplace=False)
df_te = df_te.drop(columns = drop_columns, inplace=False)

df_tr.shape
df_te.shape

(10234, 105)

### drop hospital_id

- due to the difference in its distribution between the training and test sets

In [40]:
# Remove hospital_id from both frames so they keep identical feature columns.
df_tr = df_tr.drop(columns=['hospital_id'], inplace=False)
df_te = df_te.drop(columns=['hospital_id'], inplace=False)

### readmission_status has a single unique value across the whole dataset => drop

In [41]:
# Checks on readmission_status: number of distinct values in train, its row in
# the train missing-value summary, and the distinct values in each frame.
# Only the last expression is rendered as the cell output.
df_tr.readmission_status.nunique()
tr_missing.loc[tr_missing.column_name=='readmission_status']
df_tr.readmission_status.unique()
df_te.readmission_status.unique()

array([0])

In [42]:
# Remove readmission_status from both frames.
df_tr = df_tr.drop(columns=['readmission_status'], inplace=False)
df_te = df_te.drop(columns=['readmission_status'], inplace=False)

## Combine the training and test sets

In [43]:
# Columns present in one frame but not in the other, checked in both directions.
# Only the second difference is rendered as the cell output.
set(df_tr.columns).difference(set(df_te.columns))
set(df_te.columns).difference(set(df_tr.columns))

set()

In [44]:
# Give the test frame a NaN placeholder for the target column.
df_te['diabetes_mellitus'] = np.nan
# Tag every row with its origin so the combined frame can be split again later.
df_tr['split_type'] = 'train'
df_te['split_type'] = 'test'

In [45]:
# Stack train and test into one frame so both go through the same encoding
# and imputation steps.
# ignore_index=True gives the combined frame a unique 0..n-1 index. Both inputs
# come from read_csv with its default index, so without it the test rows would
# reuse the row labels of the first train rows.
df_t = pd.concat([df_tr, df_te], ignore_index=True)
df_t.columns
df_t.shape
df_t.head()

Unnamed: 0,encounter_id,age,bmi,elective_surgery,ethnicity,gender,height,hospital_admit_source,icu_admit_source,icu_id,icu_stay_type,icu_type,pre_icu_los_days,weight,apache_2_diagnosis,apache_3j_diagnosis,apache_post_operative,arf_apache,bun_apache,creatinine_apache,gcs_eyes_apache,gcs_motor_apache,gcs_unable_apache,gcs_verbal_apache,glucose_apache,heart_rate_apache,hematocrit_apache,intubated_apache,map_apache,resprate_apache,sodium_apache,temp_apache,ventilated_apache,wbc_apache,d1_diasbp_max,d1_diasbp_min,d1_diasbp_noninvasive_max,d1_diasbp_noninvasive_min,d1_heartrate_max,d1_heartrate_min,...,h1_resprate_min,h1_spo2_max,h1_spo2_min,h1_sysbp_max,h1_sysbp_min,h1_sysbp_noninvasive_max,h1_sysbp_noninvasive_min,h1_temp_max,h1_temp_min,d1_bun_max,d1_bun_min,d1_calcium_max,d1_calcium_min,d1_creatinine_max,d1_creatinine_min,d1_glucose_max,d1_glucose_min,d1_hco3_max,d1_hco3_min,d1_hemaglobin_max,d1_hemaglobin_min,d1_hematocrit_max,d1_hematocrit_min,d1_platelets_max,d1_platelets_min,d1_potassium_max,d1_potassium_min,d1_sodium_max,d1_sodium_min,d1_wbc_max,d1_wbc_min,aids,cirrhosis,hepatic_failure,immunosuppression,leukemia,lymphoma,solid_tumor_with_metastasis,diabetes_mellitus,split_type
0,214826,68.0,22.732803,0,Caucasian,M,180.3,Floor,Floor,92,admit,CTICU,0.541667,73.9,113.0,502.01,0,0,31.0,2.51,3.0,6.0,0.0,4.0,168.0,118.0,27.4,0,40.0,36.0,134.0,39.3,0,14.1,68.0,37.0,68.0,37.0,119.0,72.0,...,18.0,100.0,74.0,131.0,115.0,131.0,115.0,39.5,37.5,31.0,30.0,8.5,7.4,2.51,2.23,168.0,109.0,19.0,15.0,8.9,8.9,27.4,27.4,233.0,233.0,4.0,3.4,136.0,134.0,14.1,14.1,0,0,0,0,0,0,0,1.0,train
1,246060,77.0,27.421875,0,Caucasian,F,160.0,Floor,Floor,90,admit,Med-Surg ICU,0.927778,70.2,108.0,203.01,0,0,9.0,0.56,1.0,3.0,0.0,1.0,145.0,120.0,36.9,0,46.0,33.0,145.0,35.1,1,12.7,95.0,31.0,95.0,31.0,118.0,72.0,...,28.0,95.0,70.0,95.0,71.0,95.0,71.0,36.3,36.3,11.0,9.0,8.6,8.0,0.71,0.56,145.0,128.0,27.0,26.0,11.3,11.1,36.9,36.1,557.0,487.0,4.2,3.8,145.0,145.0,23.3,12.7,0,0,0,0,0,0,0,1.0,train
2,276985,25.0,31.952749,0,Caucasian,F,172.7,Emergency Department,Accident & Emergency,93,admit,Med-Surg ICU,0.000694,95.3,122.0,703.03,0,0,,,3.0,6.0,0.0,5.0,,102.0,,0,68.0,37.0,,36.7,0,,88.0,48.0,88.0,48.0,96.0,68.0,...,16.0,98.0,91.0,148.0,124.0,148.0,124.0,36.7,36.7,,,,,,,,,,,,,,,,,,,,,,,0,0,0,0,0,0,0,0.0,train
3,262220,81.0,22.635548,1,Caucasian,F,165.1,Operating Room,Operating Room / Recovery,92,admit,CTICU,0.000694,61.7,203.0,1206.03,1,0,,,4.0,6.0,0.0,5.0,185.0,114.0,25.9,1,60.0,4.0,,34.8,1,8.0,48.0,42.0,48.0,42.0,116.0,92.0,...,11.0,100.0,99.0,136.0,106.0,,,35.6,34.8,,,,,,,185.0,88.0,,,11.6,8.9,34.0,25.9,198.0,43.0,5.0,3.5,,,9.0,8.0,0,0,0,0,0,0,0,0.0,train
4,201746,19.0,,0,Caucasian,M,188.0,,Accident & Emergency,91,admit,Med-Surg ICU,0.073611,,119.0,601.01,0,0,,,,,,,,60.0,,0,103.0,16.0,,36.7,0,,99.0,57.0,99.0,57.0,89.0,60.0,...,,100.0,100.0,130.0,120.0,130.0,120.0,,,,,,,,,,,,,,,,,,,,,,,,,0,0,0,0,0,0,0,0.0,train


## Categorical encoding
- Label Encoding : assign label to a unique integer
- OneHot Encoding : creating dummy variables

In [46]:
# Distinct dtypes in the combined frame, then its object-typed columns.
df_t.dtypes.unique()
# The mask is built from df_t itself so it always matches the frame being inspected.
df_t.dtypes.loc[df_t.dtypes=='O']

ethnicity                object
gender                   object
hospital_admit_source    object
icu_admit_source         object
icu_stay_type            object
icu_type                 object
split_type               object
dtype: object

In [47]:
# Categorical columns to encode: every object-typed column except the
# split marker, followed by the two APACHE diagnosis code columns.
cat_cols = [
    col for col in df_t.columns[df_t.dtypes == 'O'] if col != 'split_type'
] + ['apache_2_diagnosis', 'apache_3j_diagnosis']
print(cat_cols)

['ethnicity', 'gender', 'hospital_admit_source', 'icu_admit_source', 'icu_stay_type', 'icu_type', 'apache_2_diagnosis', 'apache_3j_diagnosis']


In [48]:
# One-hot encode the categorical columns; each is replaced by '<column>_<value>' columns.
# NOTE: this rebinds df_t without the original cat_cols, so the cell cannot be
# re-run without first re-running the cells that build df_t.
df_t = pd.get_dummies(df_t, prefix=cat_cols, columns=cat_cols)

In [49]:
# Preview the frame after one-hot encoding.
df_t.head()

Unnamed: 0,encounter_id,age,bmi,elective_surgery,height,icu_id,pre_icu_los_days,weight,apache_post_operative,arf_apache,bun_apache,creatinine_apache,gcs_eyes_apache,gcs_motor_apache,gcs_unable_apache,gcs_verbal_apache,glucose_apache,heart_rate_apache,hematocrit_apache,intubated_apache,map_apache,resprate_apache,sodium_apache,temp_apache,ventilated_apache,wbc_apache,d1_diasbp_max,d1_diasbp_min,d1_diasbp_noninvasive_max,d1_diasbp_noninvasive_min,d1_heartrate_max,d1_heartrate_min,d1_mbp_max,d1_mbp_min,d1_mbp_noninvasive_max,d1_mbp_noninvasive_min,d1_resprate_max,d1_resprate_min,d1_spo2_max,d1_spo2_min,...,apache_3j_diagnosis_1605.01,apache_3j_diagnosis_1701.01,apache_3j_diagnosis_1701.02,apache_3j_diagnosis_1701.03,apache_3j_diagnosis_1701.04,apache_3j_diagnosis_1703.01,apache_3j_diagnosis_1703.02,apache_3j_diagnosis_1703.03,apache_3j_diagnosis_1703.04,apache_3j_diagnosis_1703.05,apache_3j_diagnosis_1703.06,apache_3j_diagnosis_1703.07,apache_3j_diagnosis_1703.08,apache_3j_diagnosis_1704.01,apache_3j_diagnosis_1705.02,apache_3j_diagnosis_1705.03,apache_3j_diagnosis_1705.04,apache_3j_diagnosis_1705.05,apache_3j_diagnosis_1801.01,apache_3j_diagnosis_1801.02,apache_3j_diagnosis_1802.01,apache_3j_diagnosis_1802.02,apache_3j_diagnosis_1803.01,apache_3j_diagnosis_1803.02,apache_3j_diagnosis_1902.01,apache_3j_diagnosis_1902.02,apache_3j_diagnosis_1902.03,apache_3j_diagnosis_1902.04,apache_3j_diagnosis_1902.05,apache_3j_diagnosis_1903.01,apache_3j_diagnosis_1903.02,apache_3j_diagnosis_1903.03,apache_3j_diagnosis_1904.01,apache_3j_diagnosis_2101.01,apache_3j_diagnosis_2101.03,apache_3j_diagnosis_2201.01,apache_3j_diagnosis_2201.02,apache_3j_diagnosis_2201.03,apache_3j_diagnosis_2201.04,apache_3j_diagnosis_2201.05
0,214826,68.0,22.732803,0,180.3,92,0.541667,73.9,0,0,31.0,2.51,3.0,6.0,0.0,4.0,168.0,118.0,27.4,0,40.0,36.0,134.0,39.3,0,14.1,68.0,37.0,68.0,37.0,119.0,72.0,89.0,46.0,89.0,46.0,34.0,10.0,100.0,74.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,246060,77.0,27.421875,0,160.0,90,0.927778,70.2,0,0,9.0,0.56,1.0,3.0,0.0,1.0,145.0,120.0,36.9,0,46.0,33.0,145.0,35.1,1,12.7,95.0,31.0,95.0,31.0,118.0,72.0,120.0,38.0,120.0,38.0,32.0,12.0,100.0,70.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,276985,25.0,31.952749,0,172.7,93,0.000694,95.3,0,0,,,3.0,6.0,0.0,5.0,,102.0,,0,68.0,37.0,,36.7,0,,88.0,48.0,88.0,48.0,96.0,68.0,102.0,68.0,102.0,68.0,21.0,8.0,98.0,91.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,262220,81.0,22.635548,1,165.1,92,0.000694,61.7,1,0,,,4.0,6.0,0.0,5.0,185.0,114.0,25.9,1,60.0,4.0,,34.8,1,8.0,48.0,42.0,48.0,42.0,116.0,92.0,84.0,84.0,84.0,84.0,23.0,7.0,100.0,95.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,201746,19.0,,0,188.0,91,0.073611,,0,0,,,,,,,,60.0,,0,103.0,16.0,,36.7,0,,99.0,57.0,99.0,57.0,89.0,60.0,104.0,90.0,104.0,90.0,18.0,16.0,100.0,96.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [50]:
# Column index of the encoded frame.
df_t.columns

Index(['encounter_id', 'age', 'bmi', 'elective_surgery', 'height', 'icu_id',
       'pre_icu_los_days', 'weight', 'apache_post_operative', 'arf_apache',
       ...
       'apache_3j_diagnosis_1903.02', 'apache_3j_diagnosis_1903.03',
       'apache_3j_diagnosis_1904.01', 'apache_3j_diagnosis_2101.01',
       'apache_3j_diagnosis_2101.03', 'apache_3j_diagnosis_2201.01',
       'apache_3j_diagnosis_2201.02', 'apache_3j_diagnosis_2201.03',
       'apache_3j_diagnosis_2201.04', 'apache_3j_diagnosis_2201.05'],
      dtype='object', length=582)

### 2. Data imputation

- Possible approaches : mean, KNN, soft_impute, MICE, iterative_SVD

In [51]:
# TODO apply normalized imputation?
# SimpleFill, SoftImpute, IterativeSVD, MatrixFactorization, NuclearNormMinimization, BiScaler
def impute_data(df_t, impt_type):
    """Fill the missing values of ``df_t`` with the selected imputer.

    Parameters
    ----------
    df_t : pandas.DataFrame or array-like
        Data to impute; passed unchanged to the imputer's ``fit_transform``.
        Presumably all-numeric with NaN marking missing cells -- confirm
        against the caller.
    impt_type : str
        Imputation method:

        - ``'mice'``        -> ``IterativeImputer()``
        - ``'knn'``         -> ``KNN(orientation='columns')``
        - ``'mean'``        -> ``SimpleFill("mean")``
        - ``'soft_impute'`` -> ``SoftImpute()``

    Returns
    -------
    The value returned by ``imputer.fit_transform(df_t)``.

    Raises
    ------
    ValueError
        If ``impt_type`` is not one of the names listed above.
    """
    if impt_type == 'mice':
        imputer = IterativeImputer()
    elif impt_type == 'knn':
        imputer = KNN(orientation='columns')
    elif impt_type == 'mean':
        imputer = SimpleFill("mean")
    elif impt_type == 'soft_impute':
        imputer = SoftImpute()
    else:
        # Without this branch an unknown name left `imputer` unbound and the
        # call failed with an UnboundLocalError that hid the real cause.
        raise ValueError(
            "Unknown impt_type {!r}; expected one of "
            "'mice', 'knn', 'mean', 'soft_impute'.".format(impt_type)
        )

    return imputer.fit_transform(df_t)

In [52]:
# List the columns whose name contains 'split'.
list(filter(lambda x: x.find('split')>=0, df_t.columns))

['split_type']

In [53]:
# Impute every column except the target, the row identifier and the split marker.
impute_cols = [
    col for col in df_t.columns.values
    if col not in ('diabetes_mellitus', 'encounter_id', 'split_type')
]

impt_t = impute_data(df_t[impute_cols], 'mean')

In [54]:
# Wrap the imputed matrix in a DataFrame labelled with the imputed column names.
df_impt_t = pd.DataFrame(impt_t, columns=impute_cols)
df_impt_t.head()

Unnamed: 0,age,bmi,elective_surgery,height,icu_id,pre_icu_los_days,weight,apache_post_operative,arf_apache,bun_apache,creatinine_apache,gcs_eyes_apache,gcs_motor_apache,gcs_unable_apache,gcs_verbal_apache,glucose_apache,heart_rate_apache,hematocrit_apache,intubated_apache,map_apache,resprate_apache,sodium_apache,temp_apache,ventilated_apache,wbc_apache,d1_diasbp_max,d1_diasbp_min,d1_diasbp_noninvasive_max,d1_diasbp_noninvasive_min,d1_heartrate_max,d1_heartrate_min,d1_mbp_max,d1_mbp_min,d1_mbp_noninvasive_max,d1_mbp_noninvasive_min,d1_resprate_max,d1_resprate_min,d1_spo2_max,d1_spo2_min,d1_sysbp_max,...,apache_3j_diagnosis_1605.01,apache_3j_diagnosis_1701.01,apache_3j_diagnosis_1701.02,apache_3j_diagnosis_1701.03,apache_3j_diagnosis_1701.04,apache_3j_diagnosis_1703.01,apache_3j_diagnosis_1703.02,apache_3j_diagnosis_1703.03,apache_3j_diagnosis_1703.04,apache_3j_diagnosis_1703.05,apache_3j_diagnosis_1703.06,apache_3j_diagnosis_1703.07,apache_3j_diagnosis_1703.08,apache_3j_diagnosis_1704.01,apache_3j_diagnosis_1705.02,apache_3j_diagnosis_1705.03,apache_3j_diagnosis_1705.04,apache_3j_diagnosis_1705.05,apache_3j_diagnosis_1801.01,apache_3j_diagnosis_1801.02,apache_3j_diagnosis_1802.01,apache_3j_diagnosis_1802.02,apache_3j_diagnosis_1803.01,apache_3j_diagnosis_1803.02,apache_3j_diagnosis_1902.01,apache_3j_diagnosis_1902.02,apache_3j_diagnosis_1902.03,apache_3j_diagnosis_1902.04,apache_3j_diagnosis_1902.05,apache_3j_diagnosis_1903.01,apache_3j_diagnosis_1903.02,apache_3j_diagnosis_1903.03,apache_3j_diagnosis_1904.01,apache_3j_diagnosis_2101.01,apache_3j_diagnosis_2101.03,apache_3j_diagnosis_2201.01,apache_3j_diagnosis_2201.02,apache_3j_diagnosis_2201.03,apache_3j_diagnosis_2201.04,apache_3j_diagnosis_2201.05
0,68.0,22.732803,0.0,180.3,92.0,0.541667,73.9,0.0,0.0,31.0,2.51,3.0,6.0,0.0,4.0,168.0,118.0,27.4,0.0,40.0,36.0,134.0,39.3,0.0,14.1,68.0,37.0,68.0,37.0,119.0,72.0,89.0,46.0,89.0,46.0,34.0,10.0,100.0,74.0,131.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,77.0,27.421875,0.0,160.0,90.0,0.927778,70.2,0.0,0.0,9.0,0.56,1.0,3.0,0.0,1.0,145.0,120.0,36.9,0.0,46.0,33.0,145.0,35.1,1.0,12.7,95.0,31.0,95.0,31.0,118.0,72.0,120.0,38.0,120.0,38.0,32.0,12.0,100.0,70.0,159.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,25.0,31.952749,0.0,172.7,93.0,0.000694,95.3,0.0,0.0,25.674075,1.475558,3.0,6.0,0.0,5.0,160.232596,102.0,32.969455,0.0,68.0,37.0,137.950038,36.7,0.0,12.175376,88.0,48.0,88.0,48.0,96.0,68.0,102.0,68.0,102.0,68.0,21.0,8.0,98.0,91.0,148.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,81.0,22.635548,1.0,165.1,92.0,0.000694,61.7,1.0,0.0,25.674075,1.475558,4.0,6.0,0.0,5.0,185.0,114.0,25.9,1.0,60.0,4.0,137.950038,34.8,1.0,8.0,48.0,42.0,48.0,42.0,116.0,92.0,84.0,84.0,84.0,84.0,23.0,7.0,100.0,95.0,158.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,19.0,29.110683,0.0,188.0,91.0,0.073611,83.769687,0.0,0.0,25.674075,1.475558,3.488934,5.485565,0.011639,4.032544,160.232596,60.0,32.969455,0.0,103.0,16.0,137.950038,36.7,0.0,12.175376,99.0,57.0,99.0,57.0,89.0,60.0,104.0,90.0,104.0,90.0,18.0,16.0,100.0,96.0,147.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [55]:
# Shape of the imputed frame and number of imputed columns.
# Only the last expression is rendered as the cell output.
df_impt_t.shape
len(impute_cols)

579

In [56]:
# Re-attach the columns that were held out of imputation. Values are assigned
# positionally via .values, i.e. by row order rather than by index label.
df_impt_t['encounter_id'] = df_t.encounter_id.values
df_impt_t['diabetes_mellitus'] = df_t.diabetes_mellitus.values
df_impt_t['split_type'] = df_t.split_type.values
df_impt_t = df_impt_t.reset_index(inplace=False, drop=True)
df_impt_t.head()

Unnamed: 0,age,bmi,elective_surgery,height,icu_id,pre_icu_los_days,weight,apache_post_operative,arf_apache,bun_apache,creatinine_apache,gcs_eyes_apache,gcs_motor_apache,gcs_unable_apache,gcs_verbal_apache,glucose_apache,heart_rate_apache,hematocrit_apache,intubated_apache,map_apache,resprate_apache,sodium_apache,temp_apache,ventilated_apache,wbc_apache,d1_diasbp_max,d1_diasbp_min,d1_diasbp_noninvasive_max,d1_diasbp_noninvasive_min,d1_heartrate_max,d1_heartrate_min,d1_mbp_max,d1_mbp_min,d1_mbp_noninvasive_max,d1_mbp_noninvasive_min,d1_resprate_max,d1_resprate_min,d1_spo2_max,d1_spo2_min,d1_sysbp_max,...,apache_3j_diagnosis_1701.03,apache_3j_diagnosis_1701.04,apache_3j_diagnosis_1703.01,apache_3j_diagnosis_1703.02,apache_3j_diagnosis_1703.03,apache_3j_diagnosis_1703.04,apache_3j_diagnosis_1703.05,apache_3j_diagnosis_1703.06,apache_3j_diagnosis_1703.07,apache_3j_diagnosis_1703.08,apache_3j_diagnosis_1704.01,apache_3j_diagnosis_1705.02,apache_3j_diagnosis_1705.03,apache_3j_diagnosis_1705.04,apache_3j_diagnosis_1705.05,apache_3j_diagnosis_1801.01,apache_3j_diagnosis_1801.02,apache_3j_diagnosis_1802.01,apache_3j_diagnosis_1802.02,apache_3j_diagnosis_1803.01,apache_3j_diagnosis_1803.02,apache_3j_diagnosis_1902.01,apache_3j_diagnosis_1902.02,apache_3j_diagnosis_1902.03,apache_3j_diagnosis_1902.04,apache_3j_diagnosis_1902.05,apache_3j_diagnosis_1903.01,apache_3j_diagnosis_1903.02,apache_3j_diagnosis_1903.03,apache_3j_diagnosis_1904.01,apache_3j_diagnosis_2101.01,apache_3j_diagnosis_2101.03,apache_3j_diagnosis_2201.01,apache_3j_diagnosis_2201.02,apache_3j_diagnosis_2201.03,apache_3j_diagnosis_2201.04,apache_3j_diagnosis_2201.05,encounter_id,diabetes_mellitus,split_type
0,68.0,22.732803,0.0,180.3,92.0,0.541667,73.9,0.0,0.0,31.0,2.51,3.0,6.0,0.0,4.0,168.0,118.0,27.4,0.0,40.0,36.0,134.0,39.3,0.0,14.1,68.0,37.0,68.0,37.0,119.0,72.0,89.0,46.0,89.0,46.0,34.0,10.0,100.0,74.0,131.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,214826,1.0,train
1,77.0,27.421875,0.0,160.0,90.0,0.927778,70.2,0.0,0.0,9.0,0.56,1.0,3.0,0.0,1.0,145.0,120.0,36.9,0.0,46.0,33.0,145.0,35.1,1.0,12.7,95.0,31.0,95.0,31.0,118.0,72.0,120.0,38.0,120.0,38.0,32.0,12.0,100.0,70.0,159.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,246060,1.0,train
2,25.0,31.952749,0.0,172.7,93.0,0.000694,95.3,0.0,0.0,25.674075,1.475558,3.0,6.0,0.0,5.0,160.232596,102.0,32.969455,0.0,68.0,37.0,137.950038,36.7,0.0,12.175376,88.0,48.0,88.0,48.0,96.0,68.0,102.0,68.0,102.0,68.0,21.0,8.0,98.0,91.0,148.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,276985,0.0,train
3,81.0,22.635548,1.0,165.1,92.0,0.000694,61.7,1.0,0.0,25.674075,1.475558,4.0,6.0,0.0,5.0,185.0,114.0,25.9,1.0,60.0,4.0,137.950038,34.8,1.0,8.0,48.0,42.0,48.0,42.0,116.0,92.0,84.0,84.0,84.0,84.0,23.0,7.0,100.0,95.0,158.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,262220,0.0,train
4,19.0,29.110683,0.0,188.0,91.0,0.073611,83.769687,0.0,0.0,25.674075,1.475558,3.488934,5.485565,0.011639,4.032544,160.232596,60.0,32.969455,0.0,103.0,16.0,137.950038,36.7,0.0,12.175376,99.0,57.0,99.0,57.0,89.0,60.0,104.0,90.0,104.0,90.0,18.0,16.0,100.0,96.0,147.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,201746,0.0,train


In [57]:
# Column names of the imputed frame without the split marker.
# NOTE(review): `cols` is rebuilt from df_impt_t.columns in the scaling section
# before any use visible in this notebook, so this value appears unused --
# confirm before removing.
cols = list(df_impt_t.columns)
cols.remove('split_type')

In [58]:
# Confirm that df_impt_t carries a 'split_type' column.
list(filter(lambda x: x=='split_type', df_impt_t.columns))

['split_type']

In [59]:
# Separate the imputed frame back into its two splits.
# Train keeps the target column; only the split marker is removed.
tr = df_impt_t.loc[df_impt_t['split_type'] == 'train'].drop(columns='split_type')

# Test also loses the target column, which only holds the NaN placeholder.
te = df_impt_t.loc[df_impt_t['split_type'] == 'test'].drop(
    columns=['split_type', 'diabetes_mellitus']
)

In [60]:
# Persist the encoded, imputed, unscaled splits as parquet files.
tr.to_parquet('/content/drive/MyDrive/dataset/dummy_noscale_train.parquet')
te.to_parquet('/content/drive/MyDrive/dataset/dummy_noscale_test.parquet')

### 3. Scaling

In [61]:
# TODO : different scalers?
def scale_data(mx_t, scl_type='minmax'):
    """Scale ``mx_t`` with the selected scaler.

    Parameters
    ----------
    mx_t : pandas.DataFrame or array-like
        Data to scale; passed unchanged to the scaler's ``fit_transform``.
    scl_type : str, default 'minmax'
        Scaling method. Only ``'minmax'`` (``MinMaxScaler()``) is supported.

    Returns
    -------
    The value returned by ``scaler.fit_transform(mx_t)``.

    Raises
    ------
    ValueError
        If ``scl_type`` is not a supported name.
    """
    if scl_type == 'minmax':
        scaler = MinMaxScaler()
    else:
        # Without this branch an unknown name left `scaler` unbound and the
        # call failed with an UnboundLocalError that hid the real cause.
        raise ValueError(
            "Unknown scl_type {!r}; expected 'minmax'.".format(scl_type)
        )
    return scaler.fit_transform(mx_t)

In [62]:
# Start from every column, then strip the target, the identifier and all
# one-hot columns (matched by the source column name appearing in the
# encoded column name). What remains is the set of columns to scale.
cols = list(df_impt_t.columns)
for cat in [
    'diabetes_mellitus', 'ethnicity', 'gender', 'hospital_admit_source',
    'icu_admit_source', 'icu_stay_type', 'icu_type', 'apache_2_diagnosis',
    'apache_3j_diagnosis', 'encounter_id',
]:
    relevent_cols = [name for name in cols if cat in name]
    print(relevent_cols)
    cols = [name for name in cols if name not in relevent_cols]
df_impt_t[cols]

['diabetes_mellitus']
['ethnicity_African American', 'ethnicity_Asian', 'ethnicity_Caucasian', 'ethnicity_Hispanic', 'ethnicity_Native American', 'ethnicity_Other/Unknown']
['gender_F', 'gender_M']
['hospital_admit_source_Acute Care/Floor', 'hospital_admit_source_Chest Pain Center', 'hospital_admit_source_Direct Admit', 'hospital_admit_source_Emergency Department', 'hospital_admit_source_Floor', 'hospital_admit_source_ICU', 'hospital_admit_source_ICU to SDU', 'hospital_admit_source_Observation', 'hospital_admit_source_Operating Room', 'hospital_admit_source_Other', 'hospital_admit_source_Other Hospital', 'hospital_admit_source_Other ICU', 'hospital_admit_source_PACU', 'hospital_admit_source_Recovery Room', 'hospital_admit_source_Step-Down Unit (SDU)']
['icu_admit_source_Accident & Emergency', 'icu_admit_source_Floor', 'icu_admit_source_Operating Room / Recovery', 'icu_admit_source_Other Hospital', 'icu_admit_source_Other ICU']
['icu_stay_type_admit', 'icu_stay_type_readmit', 'icu_stay_

Unnamed: 0,age,bmi,elective_surgery,height,icu_id,pre_icu_los_days,weight,apache_post_operative,arf_apache,bun_apache,creatinine_apache,gcs_eyes_apache,gcs_motor_apache,gcs_unable_apache,gcs_verbal_apache,glucose_apache,heart_rate_apache,hematocrit_apache,intubated_apache,map_apache,resprate_apache,sodium_apache,temp_apache,ventilated_apache,wbc_apache,d1_diasbp_max,d1_diasbp_min,d1_diasbp_noninvasive_max,d1_diasbp_noninvasive_min,d1_heartrate_max,d1_heartrate_min,d1_mbp_max,d1_mbp_min,d1_mbp_noninvasive_max,d1_mbp_noninvasive_min,d1_resprate_max,d1_resprate_min,d1_spo2_max,d1_spo2_min,d1_sysbp_max,...,h1_resprate_max,h1_resprate_min,h1_spo2_max,h1_spo2_min,h1_sysbp_max,h1_sysbp_min,h1_sysbp_noninvasive_max,h1_sysbp_noninvasive_min,h1_temp_max,h1_temp_min,d1_bun_max,d1_bun_min,d1_calcium_max,d1_calcium_min,d1_creatinine_max,d1_creatinine_min,d1_glucose_max,d1_glucose_min,d1_hco3_max,d1_hco3_min,d1_hemaglobin_max,d1_hemaglobin_min,d1_hematocrit_max,d1_hematocrit_min,d1_platelets_max,d1_platelets_min,d1_potassium_max,d1_potassium_min,d1_sodium_max,d1_sodium_min,d1_wbc_max,d1_wbc_min,aids,cirrhosis,hepatic_failure,immunosuppression,leukemia,lymphoma,solid_tumor_with_metastasis,split_type
0,68.0,22.732803,0.0,180.3,92.0,0.541667,73.900000,0.0,0.0,31.000000,2.510000,3.000000,6.000000,0.000000,4.000000,168.000000,118.0,27.400000,0.0,40.0,36.0,134.000000,39.3,0.0,14.100000,68.0,37.0,68.0,37.0,119.0,72.0,89.0,46.0,89.0,46.0,34.0,10.0,100.0,74.0,131.0,...,26.000000,18.00000,100.000000,74.00000,131.000000,115.000000,131.000000,115.000000,39.500000,37.500000,31.000000,30.000000,8.500000,7.400000,2.510000,2.230000,168.00000,109.000000,19.000000,15.000000,8.900000,8.900000,27.400000,27.400000,233.000000,233.000000,4.000000,3.400000,136.00000,134.000000,14.100000,14.100000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,train
1,77.0,27.421875,0.0,160.0,90.0,0.927778,70.200000,0.0,0.0,9.000000,0.560000,1.000000,3.000000,0.000000,1.000000,145.000000,120.0,36.900000,0.0,46.0,33.0,145.000000,35.1,1.0,12.700000,95.0,31.0,95.0,31.0,118.0,72.0,120.0,38.0,120.0,38.0,32.0,12.0,100.0,70.0,159.0,...,31.000000,28.00000,95.000000,70.00000,95.000000,71.000000,95.000000,71.000000,36.300000,36.300000,11.000000,9.000000,8.600000,8.000000,0.710000,0.560000,145.00000,128.000000,27.000000,26.000000,11.300000,11.100000,36.900000,36.100000,557.000000,487.000000,4.200000,3.800000,145.00000,145.000000,23.300000,12.700000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,train
2,25.0,31.952749,0.0,172.7,93.0,0.000694,95.300000,0.0,0.0,25.674075,1.475558,3.000000,6.000000,0.000000,5.000000,160.232596,102.0,32.969455,0.0,68.0,37.0,137.950038,36.7,0.0,12.175376,88.0,48.0,88.0,48.0,96.0,68.0,102.0,68.0,102.0,68.0,21.0,8.0,98.0,91.0,148.0,...,20.000000,16.00000,98.000000,91.00000,148.000000,124.000000,148.000000,124.000000,36.700000,36.700000,25.535824,23.515146,8.373893,8.158161,1.487727,1.358696,174.16197,114.509667,24.462418,23.203704,11.463989,10.902546,34.562535,32.931879,205.646722,194.621195,4.252062,3.927516,139.15532,137.693919,12.530771,11.284294,0.0,0.0,0.0,0.0,0.0,0.0,0.0,train
3,81.0,22.635548,1.0,165.1,92.0,0.000694,61.700000,1.0,0.0,25.674075,1.475558,4.000000,6.000000,0.000000,5.000000,185.000000,114.0,25.900000,1.0,60.0,4.0,137.950038,34.8,1.0,8.000000,48.0,42.0,48.0,42.0,116.0,92.0,84.0,84.0,84.0,84.0,23.0,7.0,100.0,95.0,158.0,...,12.000000,11.00000,100.000000,99.00000,136.000000,106.000000,132.973246,116.072897,35.600000,34.800000,25.535824,23.515146,8.373893,8.158161,1.487727,1.358696,185.00000,88.000000,24.462418,23.203704,11.600000,8.900000,34.000000,25.900000,198.000000,43.000000,5.000000,3.500000,139.15532,137.693919,9.000000,8.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,train
4,19.0,29.110683,0.0,188.0,91.0,0.073611,83.769687,0.0,0.0,25.674075,1.475558,3.488934,5.485565,0.011639,4.032544,160.232596,60.0,32.969455,0.0,103.0,16.0,137.950038,36.7,0.0,12.175376,99.0,57.0,99.0,57.0,89.0,60.0,104.0,90.0,104.0,90.0,18.0,16.0,100.0,96.0,147.0,...,22.517258,17.07467,100.000000,100.00000,130.000000,120.000000,130.000000,120.000000,36.721634,36.613362,25.535824,23.515146,8.373893,8.158161,1.487727,1.358696,174.16197,114.509667,24.462418,23.203704,11.463989,10.902546,34.562535,32.931879,205.646722,194.621195,4.252062,3.927516,139.15532,137.693919,12.530771,11.284294,0.0,0.0,0.0,0.0,0.0,0.0,0.0,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
140386,36.0,37.500000,0.0,170.1,1108.0,1.696528,108.600000,0.0,0.0,25.674075,1.475558,3.000000,6.000000,0.000000,5.000000,160.232596,111.0,29.000000,0.0,127.0,45.0,137.950038,36.5,0.0,7.200000,98.0,68.0,98.0,68.0,93.0,50.0,127.0,92.0,127.0,92.0,73.0,8.0,100.0,97.0,173.0,...,17.000000,17.00000,98.000000,98.00000,132.000000,128.000000,132.000000,128.000000,36.500000,36.500000,45.000000,45.000000,8.300000,8.300000,3.230000,3.230000,96.00000,96.000000,21.000000,21.000000,10.100000,9.500000,30.000000,27.000000,170.000000,144.000000,5.000000,5.000000,137.00000,137.000000,11.700000,7.200000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,test
140387,61.0,32.100000,0.0,160.0,1108.0,0.033333,82.300000,0.0,0.0,33.000000,1.150000,4.000000,6.000000,0.000000,5.000000,94.000000,106.0,27.000000,0.0,166.0,49.0,139.000000,36.7,0.0,11.200000,116.0,56.0,116.0,56.0,103.0,66.0,166.0,81.0,166.0,81.0,49.0,9.0,100.0,95.0,227.6,...,22.517258,17.07467,98.096228,95.26395,133.187031,115.945046,132.973246,116.072897,36.721634,36.613362,33.000000,33.000000,7.900000,7.900000,1.150000,1.150000,94.00000,94.000000,24.000000,24.000000,9.400000,8.900000,30.000000,27.000000,228.000000,228.000000,3.600000,3.600000,139.00000,139.000000,11.200000,11.200000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,test
140388,74.0,22.700000,0.0,165.1,1108.0,0.757639,62.000000,0.0,0.0,25.674075,1.475558,4.000000,6.000000,0.000000,5.000000,160.232596,47.0,32.969455,0.0,113.0,41.0,137.950038,36.6,0.0,12.175376,82.0,49.0,82.0,49.0,77.0,47.0,118.0,67.0,118.0,67.0,37.0,10.0,100.0,93.0,185.0,...,20.000000,17.00000,96.000000,93.00000,112.000000,102.000000,112.000000,102.000000,36.721634,36.613362,23.000000,23.000000,7.400000,7.400000,1.090000,1.090000,150.00000,150.000000,26.000000,26.000000,9.900000,9.900000,30.000000,30.000000,87.000000,87.000000,4.300000,4.300000,141.00000,141.000000,5.500000,5.500000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,test
140389,90.0,19.900000,0.0,160.0,1108.0,0.087500,50.900000,0.0,0.0,25.674075,1.475558,4.000000,6.000000,0.000000,4.000000,160.232596,94.0,32.969455,0.0,104.0,54.0,137.950038,36.1,0.0,12.175376,70.0,57.0,70.0,57.0,129.0,61.0,99.0,74.0,99.0,74.0,48.0,18.0,100.0,96.0,150.0,...,26.000000,24.00000,100.000000,100.00000,136.000000,136.000000,136.000000,136.000000,36.100000,36.100000,16.000000,16.000000,8.600000,8.600000,0.920000,0.920000,98.00000,98.000000,40.000000,39.000000,11.900000,11.900000,37.000000,37.000000,297.000000,297.000000,3.700000,3.700000,139.00000,139.000000,6.400000,6.400000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,test


In [63]:
# Drop the split marker from the columns to scale. Filtering (instead of
# list.remove) keeps this cell safe to re-run: remove() raises ValueError
# once the name has already been taken out of the list.
cols = [col for col in cols if col != 'split_type']
# Scale the selected columns in place; the scaler is fitted on the rows of
# both splits together because df_impt_t still holds train and test.
sc_impt_t = scale_data(df_impt_t[cols], 'minmax')
df_impt_t[cols] = sc_impt_t

### 4. Save

In [64]:
# Separate the scaled frame back into its two splits.
# Train keeps the target column; only the split marker is removed.
s_tr = df_impt_t.loc[df_impt_t['split_type'] == 'train'].drop(columns='split_type')

# Test also loses the target column, which only holds the NaN placeholder.
s_te = df_impt_t.loc[df_impt_t['split_type'] == 'test'].drop(
    columns=['split_type', 'diabetes_mellitus']
)

# NOTE(review): the file names say 'softimpute', but the imputation call
# visible in this notebook uses 'mean' -- confirm the names match the data.
s_tr.to_parquet('/content/drive/MyDrive/dataset/train_scale_softimpute_.parquet')
s_te.to_parquet('/content/drive/MyDrive/dataset/test_scale_softimpute_.parquet')
#s_tr.to_parquet('../dataset/train_scale_.parquet')
#s_te.to_parquet('../dataset/test_scale_.parquet')