In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

# Reproducibility: fix the seed of NumPy's global random generator.
np.random.seed(1337)

# Display: print NumPy floats without scientific notation and let pandas show
# every column and row of a frame (None = no limit; use e.g. 1000 to cap it).
np.set_printoptions(suppress=True)
for _display_option in ('display.max_columns', 'display.max_rows'):
    pd.set_option(_display_option, None)

# Data

In [2]:
# Locations of the three CSV inputs, relative to the notebook's working directory.
train_features_filepath = 'data/train_features.csv'
test_features_filepath = 'data/test_features.csv'
train_labels_filepath = 'data/train_labels.csv'

# Read all three files with pandas' round-trip float converter.
train_features_df, test_features_df, train_labels_df = (
    pd.read_csv(csv_path, float_precision="round_trip")
    for csv_path in (train_features_filepath, test_features_filepath, train_labels_filepath)
)
print("Train shape:", train_features_df.shape, "| Train label shape:", train_labels_df.shape, "| Test shape:", test_features_df.shape)

# First rows of every frame, followed by summary statistics of the training data.
display(
    train_features_df.head(40),
    train_labels_df.head(40),
    test_features_df.head(30),
)

display(
    train_features_df.describe(),
    train_labels_df.describe(),
)

Train shape: (227940, 37) | Train label shape: (18995, 16) | Test shape: (151968, 37)


Unnamed: 0,pid,Time,Age,EtCO2,PTT,BUN,Lactate,Temp,Hgb,HCO3,BaseExcess,RRate,Fibrinogen,Phosphate,WBC,Creatinine,PaCO2,AST,FiO2,Platelets,SaO2,Glucose,ABPm,Magnesium,Potassium,ABPd,Calcium,Alkalinephos,SpO2,Bilirubin_direct,Chloride,Hct,Heartrate,Bilirubin_total,TroponinI,ABPs,pH
0,1,3,34.0,,,12.0,,36.0,8.7,24.0,-2.0,16.0,,,6.3,,45.0,,,,,,84.0,1.2,3.8,61.0,,,100.0,,114.0,24.6,94.0,,,142.0,7.33
1,1,4,34.0,,,,,36.0,,,-2.0,16.0,,,,,,,0.5,,,,81.0,,,62.5,,,100.0,,,,99.0,,,125.0,7.33
2,1,5,34.0,,,,,36.0,,,0.0,18.0,,,,,43.0,,0.4,,,,74.0,,,59.0,,,100.0,,,,92.0,,,110.0,7.37
3,1,6,34.0,,,,,37.0,,,0.0,18.0,,,,,,,,,,,66.0,,,49.5,,,100.0,,,,88.0,,,104.0,7.37
4,1,7,34.0,,,,,,,,,18.0,,,,,,,,,,,63.0,1.8,,48.0,,,100.0,,,22.4,81.0,,,100.0,7.41
5,1,8,34.0,,,,,37.0,,,,16.0,,,,,,,0.4,,,,68.0,1.8,,51.0,,,100.0,,,22.4,82.0,,,106.0,
6,1,9,34.0,,,,,37.0,,,,18.0,,,,,,,,,,,65.0,,,46.0,,,100.0,,,,67.0,,,112.0,
7,1,10,34.0,,,,,37.0,,,,18.0,,,,,,,,,,,68.0,,,47.0,,,100.0,,,,62.0,,,121.0,
8,1,11,34.0,,,12.0,,,8.5,26.0,,12.0,,4.6,4.7,0.5,,,,143.0,,120.0,67.0,2.1,4.1,47.0,7.6,,100.0,,111.0,23.8,58.0,,,118.0,
9,1,12,34.0,,,12.0,,38.0,8.5,26.0,0.0,18.0,,,4.7,,42.0,,0.4,,,,62.0,2.1,4.1,44.0,,,100.0,,111.0,23.8,66.0,,,110.0,7.39


Unnamed: 0,pid,LABEL_BaseExcess,LABEL_Fibrinogen,LABEL_AST,LABEL_Alkalinephos,LABEL_Bilirubin_total,LABEL_Lactate,LABEL_TroponinI,LABEL_SaO2,LABEL_Bilirubin_direct,LABEL_EtCO2,LABEL_Sepsis,LABEL_RRate,LABEL_ABPm,LABEL_SpO2,LABEL_Heartrate
0,1,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,12.1,85.4,100.0,59.9
1,10,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,17.8,100.6,95.5,85.5
2,100,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,16.5,88.3,96.5,108.1
3,1000,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,19.4,77.2,98.3,80.9
4,10000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,12.6,76.8,97.7,95.3
5,10002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,14.5,67.4,99.1,64.6
6,10006,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,17.2,84.9,96.8,90.5
7,10007,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,23.8,97.0,94.3,76.0
8,10009,1.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,18.9,72.0,95.3,91.5
9,1001,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,14.9,83.3,99.9,117.3


Unnamed: 0,pid,Time,Age,EtCO2,PTT,BUN,Lactate,Temp,Hgb,HCO3,BaseExcess,RRate,Fibrinogen,Phosphate,WBC,Creatinine,PaCO2,AST,FiO2,Platelets,SaO2,Glucose,ABPm,Magnesium,Potassium,ABPd,Calcium,Alkalinephos,SpO2,Bilirubin_direct,Chloride,Hct,Heartrate,Bilirubin_total,TroponinI,ABPs,pH
0,0,1,39.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1,0,2,39.0,,44.2,17.0,,36.0,10.2,13.0,,,147.0,6.0,17.5,2.2,,32.0,0.6,194.0,,273.0,77.0,2.2,4.6,76.0,8.0,119.0,100.0,,98.0,31.0,82.0,21.8,,119.0,
2,0,3,39.0,,,,,,,,-9.0,13.0,,,,,26.0,,0.55,,,,78.0,,,72.5,,,100.0,,,,78.0,,,125.0,7.34
3,0,4,39.0,,,,,,,,,12.0,,,,,,,0.5,,,,87.0,,,66.0,,,100.0,,,,80.0,,,136.0,
4,0,5,39.0,,,,,,,,,,,,,,,,,,,,86.0,,,65.0,,,100.0,,,,83.0,,,135.0,
5,0,6,39.0,,,,,36.0,,,,10.0,,,,,,,,,,,89.0,,,66.0,,,100.0,,,,88.0,,,144.0,
6,0,7,39.0,,38.5,20.0,,,9.1,16.0,,12.0,,4.8,18.5,2.4,,31.0,0.5,193.0,,162.0,,2.5,3.7,,8.3,109.0,100.0,,102.0,25.9,,26.4,,,
7,0,8,39.0,,,,,36.0,,,-4.0,12.0,,,,,30.0,,0.5,,,,80.0,,,59.0,,,100.0,,,,90.0,,,129.0,7.4
8,0,9,39.0,,,,,36.0,,,,12.0,,,,,,,,,,,75.0,,,56.5,,,100.0,,,,90.0,,,121.0,
9,0,10,39.0,,,,,36.0,,,,11.0,,,,,,,,,,,74.0,,,55.0,,,100.0,,,,85.0,,,120.0,


Unnamed: 0,pid,Time,Age,EtCO2,PTT,BUN,Lactate,Temp,Hgb,HCO3,BaseExcess,RRate,Fibrinogen,Phosphate,WBC,Creatinine,PaCO2,AST,FiO2,Platelets,SaO2,Glucose,ABPm,Magnesium,Potassium,ABPd,Calcium,Alkalinephos,SpO2,Bilirubin_direct,Chloride,Hct,Heartrate,Bilirubin_total,TroponinI,ABPs,pH
count,227940.0,227940.0,227940.0,9783.0,10299.0,20105.0,10756.0,81115.0,22295.0,12559.0,19887.0,187785.0,2493.0,11590.0,19083.0,17792.0,21043.0,5761.0,26602.0,18035.0,13014.0,47036.0,195889.0,17523.0,28393.0,152418.0,17830.0,5708.0,195192.0,719.0,13917.0,27297.0,200128.0,5326.0,3776.0,191650.0,25046.0
mean,15788.831219,7.014399,62.073809,32.883114,40.09131,23.192664,2.859716,36.852136,10.628208,23.4881,-1.239284,18.154043,262.496911,3.612519,11.738649,1.495777,41.115696,193.444888,0.701666,204.666426,93.010527,142.169407,82.117276,2.004149,4.152729,64.014711,7.161149,97.796163,97.663449,1.390723,106.260185,31.28309,84.522371,1.640941,7.26924,122.369877,7.367231
std,9151.896286,4.716103,16.451854,7.802065,26.034961,20.024289,2.428368,0.875152,2.074859,4.40378,4.192677,5.037031,133.02091,1.384462,10.088872,1.898112,8.929873,682.836708,24.522126,104.156406,10.887271,56.89453,16.471871,0.437286,0.670168,13.920097,2.812067,122.773379,2.786186,2.792722,5.916082,5.770425,17.643437,3.244145,25.172442,23.273834,0.074384
min,1.0,1.0,15.0,10.0,12.5,1.0,0.2,21.0,3.3,0.0,-29.0,1.0,34.0,0.2,0.1,0.1,10.0,5.0,0.0,2.0,24.0,15.0,20.0,0.5,1.3,20.0,1.0,12.0,20.0,0.01,66.0,9.4,23.0,0.1,0.01,21.0,6.82
25%,7879.0,4.0,52.0,28.5,27.8,12.0,1.4,36.0,9.2,21.0,-3.0,15.0,177.0,2.8,7.6,0.7,36.0,21.0,0.4,137.0,95.0,108.0,71.0,1.7,3.7,54.25,7.3,53.0,96.0,0.1,103.0,27.3,72.0,0.5,0.03,105.0,7.33
50%,15726.0,7.0,64.0,33.0,32.2,17.0,2.1,37.0,10.5,23.9,-1.0,18.0,233.0,3.4,10.4,0.9,40.0,36.0,0.5,189.0,97.0,130.0,80.0,2.0,4.1,62.0,8.2,72.0,98.0,0.3,107.0,30.9,83.0,0.8,0.15,119.0,7.37
75%,23725.0,10.0,74.0,38.0,40.6,27.0,3.4,37.0,12.0,26.0,0.0,21.0,316.0,4.2,14.0,1.38,45.0,84.0,0.6,251.0,98.0,160.0,91.0,2.2,4.5,72.0,8.7,104.0,100.0,1.21,110.0,35.0,95.0,1.4,2.05,137.0,7.41
max,31658.0,315.0,100.0,100.0,250.0,268.0,31.0,42.0,23.8,50.0,100.0,97.0,1179.0,16.4,440.0,41.9,100.0,9961.0,4000.0,2322.0,100.0,952.0,300.0,9.6,10.75,298.0,20.6,3833.0,100.0,21.2,141.0,63.4,191.0,46.5,440.0,287.0,7.78


Unnamed: 0,pid,LABEL_BaseExcess,LABEL_Fibrinogen,LABEL_AST,LABEL_Alkalinephos,LABEL_Bilirubin_total,LABEL_Lactate,LABEL_TroponinI,LABEL_SaO2,LABEL_Bilirubin_direct,LABEL_EtCO2,LABEL_Sepsis,LABEL_RRate,LABEL_ABPm,LABEL_SpO2,LABEL_Heartrate
count,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0,18995.0
mean,15788.831219,0.268281,0.073704,0.239747,0.23622,0.24059,0.200211,0.099763,0.233693,0.033904,0.066017,0.057278,18.79596,82.511171,96.947311,84.119716
std,9152.117122,0.443076,0.261295,0.42694,0.42477,0.427453,0.400168,0.299692,0.42319,0.180986,0.248319,0.23238,3.511241,12.74511,2.110957,14.718396
min,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,26.0,27.0,30.2
25%,7879.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,16.55,73.2,95.9,73.7
50%,15726.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,18.4,81.0,97.1,83.4
75%,23724.5,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,20.6,90.2,98.3,93.6
max,31658.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,41.1,147.1,100.0,155.6


In [3]:
# Keep only the label rows of patients whose sepsis label is 1 and inspect them.
is_sepsis_positive = train_labels_df['LABEL_Sepsis'] == 1
labels_sepsis_positive = train_labels_df.loc[is_sepsis_positive]
display(labels_sepsis_positive.describe(), labels_sepsis_positive.head(100))

# Disabled exploration, kept for reference. It calls group_columns_in_df and
# scale_df, which are only defined in later cells, so it cannot run from here
# on a fresh kernel.
#temp_grouped = group_columns_in_df(train_features_df)
#display(temp_grouped.head(10))
#features_sepsis_positive = temp_grouped.loc[temp_grouped['pid'].isin(labels_sepsis_positive['pid'])]
#scale_df(features_sepsis_positive)
#display(features_sepsis_positive.describe())

Unnamed: 0,pid,LABEL_BaseExcess,LABEL_Fibrinogen,LABEL_AST,LABEL_Alkalinephos,LABEL_Bilirubin_total,LABEL_Lactate,LABEL_TroponinI,LABEL_SaO2,LABEL_Bilirubin_direct,LABEL_EtCO2,LABEL_Sepsis,LABEL_RRate,LABEL_ABPm,LABEL_SpO2,LABEL_Heartrate
count,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0,1088.0
mean,15935.926471,0.547794,0.17739,0.465993,0.460478,0.469669,0.605699,0.129596,0.511949,0.100184,0.189338,1.0,19.881541,80.647059,97.036397,88.546599
std,9166.983876,0.497939,0.382174,0.499072,0.498665,0.499309,0.488925,0.336012,0.500087,0.300383,0.391957,0.0,4.631534,11.236957,2.451813,14.966026
min,13.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,51.8,67.0,49.8
25%,8043.25,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,16.7,72.8,95.9,78.175
50%,16180.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,19.5,79.55,97.4,87.8
75%,23941.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,22.5,88.0,98.6,98.325
max,31515.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,36.9,125.3,100.0,134.7


Unnamed: 0,pid,LABEL_BaseExcess,LABEL_Fibrinogen,LABEL_AST,LABEL_Alkalinephos,LABEL_Bilirubin_total,LABEL_Lactate,LABEL_TroponinI,LABEL_SaO2,LABEL_Bilirubin_direct,LABEL_EtCO2,LABEL_Sepsis,LABEL_RRate,LABEL_ABPm,LABEL_SpO2,LABEL_Heartrate
7,10007,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,23.8,97.0,94.3,76.0
20,10023,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,13.3,80.3,96.7,78.0
25,10030,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,17.9,85.3,99.8,95.2
27,10034,1.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,16.1,89.2,99.7,65.4
31,10038,1.0,1.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,0.0,1.0,18.4,62.0,97.1,81.1
35,10048,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,18.4,95.5,97.6,98.6
53,10072,0.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,15.7,78.8,97.6,91.7
83,10116,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,15.8,91.2,99.3,74.9
93,10130,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,20.6,78.2,97.0,101.8
109,10150,0.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,0.0,1.0,18.1,72.8,97.9,67.0


## Grouping

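* Average every column per patient (`groupby` on `'pid'`), then fill the values that are still missing with the mean of the patient's age range; the ranges are the half-open intervals between consecutive edges in `[0, 20, 30, 40, 50, 60, 70, 80, 110]`
* The `'Time'` column is dropped before averaging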

In [4]:
def group_columns_in_df(df):
    """Collapse the per-time-step records into one row per patient and impute by age range.

    Every column except 'Time' is averaged per 'pid', ignoring missing values.
    A value that is still missing afterwards (the patient never had that
    measurement) is replaced by the mean of that column over all patients in
    the same age range. The ranges are the half-open intervals
    [0, 20), [20, 30), ..., [80, 110).

    Values are aligned through the row index of the per-patient table, so a
    patient whose age lies outside every range (or is missing) keeps its own
    un-imputed row and cannot shift other patients' values. A column with no
    observation in a whole age range stays NaN.

    Args:
        df (pd.DataFrame): raw records with at least the columns 'pid', 'Time'
            and 'Age'. 'Age' is expected to be constant per patient.

    Returns:
        pd.DataFrame: one row per patient in order of first appearance, with
        the columns of ``df`` minus 'Time' and an integer 'pid' column.
    """
    age_ranges = [0, 20, 30, 40, 50, 60, 70, 80, 110]

    # One row per patient: NaN-skipping mean of every column except 'Time'.
    grouped_df = (
        df.loc[:, df.columns != 'Time']
        .groupby('pid', as_index=False, sort=False)
        .mean()
    )

    # Impute each age range separately; the slices keep grouped_df's row index.
    filled_parts = []
    for age_l, age_u in zip(age_ranges[:-1], age_ranges[1:]):
        in_range = (grouped_df['Age'] >= age_l) & (grouped_df['Age'] < age_u)
        sub_df = grouped_df[in_range]
        if sub_df.empty:
            continue
        filled_parts.append(sub_df.fillna(sub_df.mean()))

    result = grouped_df
    if filled_parts:
        # Index-aligned fill: only the holes of grouped_df receive values.
        result = grouped_df.fillna(pd.concat(filled_parts))

    # Guard against an upcast of the identifier column during the aligned fill.
    result = result.astype({"pid": int})
    assert grouped_df.shape == result.shape
    return result

# Build the per-patient tables for the training and the test split.
X_df, X_test_df = (
    group_columns_in_df(split_df)
    for split_df in (train_features_df, test_features_df)
)

display(X_df.shape, X_df.head(20))

(18995, 36)

Unnamed: 0,pid,Age,EtCO2,PTT,BUN,Lactate,Temp,Hgb,HCO3,BaseExcess,RRate,Fibrinogen,Phosphate,WBC,Creatinine,PaCO2,AST,FiO2,Platelets,SaO2,Glucose,ABPm,Magnesium,Potassium,ABPd,Calcium,Alkalinephos,SpO2,Bilirubin_direct,Chloride,Hct,Heartrate,Bilirubin_total,TroponinI,ABPs,pH
0,1,34.0,35.270246,36.296005,12.0,2.267844,36.75,8.566667,25.333333,-0.666667,17.0,282.406566,4.6,5.233333,0.5,43.333333,179.917422,0.425,143.0,95.156566,120.0,68.333333,1.8,4.0,50.25,7.6,101.242051,100.0,1.97375,112.0,23.2,77.083333,2.196677,3.712206,114.5,7.37
1,10,71.0,31.424716,27.8,12.0,2.395573,36.0,14.6,23.828877,-0.881882,18.090909,274.985885,2.5,11.5,0.82,40.970572,20.0,0.552612,207.0,92.402792,152.0,101.727273,1.5,3.2,83.272727,8.6,68.0,98.0,0.608505,106.167733,42.1,78.818182,1.3,0.01,132.909091,7.375165
2,100,68.0,32.960816,20.9,21.0,2.37691,36.25,12.5,27.0,-0.727915,14.833333,270.562388,3.5,12.5,1.1,41.495807,146.01825,1.069466,204.0,93.827695,243.0,81.833333,1.7,3.6,62.833333,9.0,90.714723,96.5,1.046905,101.0,36.8,109.083333,1.435074,6.996616,117.0,7.372458
3,1000,79.0,31.863636,39.371346,22.0,3.855,36.818182,9.2,23.828877,-0.881882,12.0,274.985885,1.9,19.6,0.96,44.0,86.03802,0.4,158.0,98.0,128.625,83.454545,2.0,3.966667,62.818182,3.463333,97.562013,98.818182,0.608505,106.167733,27.3,86.363636,1.197754,5.817339,141.909091,7.3
4,10000,76.0,31.424716,28.55,22.0,2.395573,36.75,10.7,25.5,1.5,12.090909,274.985885,3.609838,7.75,1.0,44.5,86.03802,0.5,135.0,98.25,121.75,69.090909,1.4,3.9,48.227273,7.70575,97.562013,98.545455,0.608505,103.5,30.3,77.090909,1.197754,5.817339,123.0,7.39
5,10002,73.0,19.0,31.3,18.0,3.005,37.0,10.4,23.828877,-0.881882,19.625,161.0,3.0,10.3,0.98,40.5,41.0,0.8,83.0,92.402792,127.166667,69.818182,2.1,4.475,48.5,3.12,38.0,99.181818,0.608505,107.0,30.3,67.090909,0.8,5.817339,132.090909,7.375
6,10006,51.0,34.020707,37.275808,21.317991,2.395311,37.5,10.949367,23.963837,-0.580284,18.888889,268.857981,3.661243,11.297389,1.438611,41.414587,176.596782,0.538963,212.857003,94.235683,200.5,70.555556,1.987571,4.112962,48.714286,7.742205,98.531877,96.666667,1.586346,105.757518,32.484703,82.0,1.816378,7.683001,117.888889,7.374335
7,10007,60.0,32.960816,39.756291,23.358213,2.37691,38.0,10.672379,24.035953,-0.727915,21.909091,270.562388,2.4,11.619513,1.50062,41.495807,146.01825,1.069466,209.292544,93.827695,87.0,108.181818,1.6,3.6,88.363636,7.524813,90.714723,94.909091,1.046905,105.700469,31.858225,79.909091,1.435074,0.08,139.363636,7.372458
8,10009,69.0,32.960816,86.05,15.0,2.37691,37.25,12.2,21.0,-0.727915,22.5,270.562388,3.1,8.7,0.5,41.495807,17.0,1.069466,182.0,93.827695,109.0,65.909091,1.2,4.1,65.227273,8.1,64.0,97.090909,1.046905,104.0,34.0,97.727273,0.8,6.996616,90.909091,7.372458
9,1001,36.0,35.270246,31.2,10.0,1.8,37.666667,10.4,31.0,7.5,13.363636,282.406566,2.1,8.7,0.5,40.25,179.917422,0.416667,205.0,95.156566,98.666667,85.909091,2.0,3.1,69.181818,8.5,101.242051,100.0,1.97375,105.0,29.8,106.727273,2.196677,3.712206,113.090909,7.515


## Adding features

In [5]:
def agg_fn(x):
    """Return the spread (largest minus smallest) of the non-missing values in ``x``.

    Args:
        x (pd.Series): values of one column for one group.

    Returns:
        0 when ``x`` holds no observed value at all, otherwise the absolute
        difference between the NaN-ignoring maximum and minimum.
    """
    has_observation = x.notna().any()
    if not has_observation:
        return 0
    return np.abs(np.nanmax(x) - np.nanmin(x))

def add_features(df_to_group, df_to_add_features_to):
    """Add maximum absolute difference features for each column to a dataframe

    For every column of ``df_to_group`` except 'pid', 'Age' and 'Time', the
    spread (max - min of the observed values) is computed per 'pid' and stored
    in a new column named '<column>_Diff'. A patient without any observation
    in a column gets 0 for that column. The new columns are attached by
    matching on 'pid', so the result is correct regardless of the row order or
    index of ``df_to_add_features_to``.

    Args:
        df_to_group (pd.DataFrame): the dataframe that will be grouped by 'pid'
        df_to_add_features_to (pd.DataFrame): the dataframe to add new features to;
            must contain a 'pid' column

    Returns:
        pd.DataFrame: the new dataframe with all features; rows, row order and
        index are those of ``df_to_add_features_to``
    """
    per_patient = df_to_group.drop(['Age', 'Time'], axis=1).groupby('pid', sort=False)
    # max/min skip NaN; a group with no observation yields NaN, which becomes 0.
    # max >= min always holds, so the difference is already the absolute value.
    features_df = (per_patient.max() - per_patient.min()).fillna(0).add_suffix('_Diff')
    # features_df is indexed by pid, so join on the 'pid' column of the target.
    return df_to_add_features_to.join(features_df, on='pid')

# Append the per-patient '<column>_Diff' features to both splits.
# Note: both names are reassigned from themselves, so this cell is meant to be
# run once per kernel session.
X_df = add_features(df_to_group=train_features_df, df_to_add_features_to=X_df)
X_test_df = add_features(df_to_group=test_features_df, df_to_add_features_to=X_test_df)

display(X_df.shape, X_df.head(30))

(18995, 70)

Unnamed: 0,pid,Age,EtCO2,PTT,BUN,Lactate,Temp,Hgb,HCO3,BaseExcess,RRate,Fibrinogen,Phosphate,WBC,Creatinine,PaCO2,AST,FiO2,Platelets,SaO2,Glucose,ABPm,Magnesium,Potassium,ABPd,Calcium,Alkalinephos,SpO2,Bilirubin_direct,Chloride,Hct,Heartrate,Bilirubin_total,TroponinI,ABPs,pH,EtCO2_Diff,PTT_Diff,BUN_Diff,Lactate_Diff,Temp_Diff,Hgb_Diff,HCO3_Diff,BaseExcess_Diff,RRate_Diff,Fibrinogen_Diff,Phosphate_Diff,WBC_Diff,Creatinine_Diff,PaCO2_Diff,AST_Diff,FiO2_Diff,Platelets_Diff,SaO2_Diff,Glucose_Diff,ABPm_Diff,Magnesium_Diff,Potassium_Diff,ABPd_Diff,Calcium_Diff,Alkalinephos_Diff,SpO2_Diff,Bilirubin_direct_Diff,Chloride_Diff,Hct_Diff,Heartrate_Diff,Bilirubin_total_Diff,TroponinI_Diff,ABPs_Diff,pH_Diff
0,1,34.0,35.270246,36.296005,12.0,2.267844,36.75,8.566667,25.333333,-0.666667,17.0,282.406566,4.6,5.233333,0.5,43.333333,179.917422,0.425,143.0,95.156566,120.0,68.333333,1.8,4.0,50.25,7.6,101.242051,100.0,1.97375,112.0,23.2,77.083333,2.196677,3.712206,114.5,7.37,0.0,0.0,0.0,0.0,2.0,0.2,2.0,2.0,6.0,0.0,0.0,1.6,0.0,3.0,0.0,0.1,0.0,0.0,0.0,31.0,0.9,0.3,23.5,0.0,0.0,0.0,0.0,3.0,2.4,41.0,0.0,0.0,42.0,0.08
1,10,71.0,31.424716,27.8,12.0,2.395573,36.0,14.6,23.828877,-0.881882,18.090909,274.985885,2.5,11.5,0.82,40.970572,20.0,0.552612,207.0,92.402792,152.0,101.727273,1.5,3.2,83.272727,8.6,68.0,98.0,0.608505,106.167733,42.1,78.818182,1.3,0.01,132.909091,7.375165,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,37.0,0.0,0.0,35.0,0.0,0.0,3.0,0.0,0.0,0.0,19.0,0.0,0.0,39.0,0.0
2,100,68.0,32.960816,20.9,21.0,2.37691,36.25,12.5,27.0,-0.727915,14.833333,270.562388,3.5,12.5,1.1,41.495807,146.01825,1.069466,204.0,93.827695,243.0,81.833333,1.7,3.6,62.833333,9.0,90.714723,96.5,1.046905,101.0,36.8,109.083333,1.435074,6.996616,117.0,7.372458,0.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,32.0,0.0,0.0,20.0,0.0,0.0,8.0,0.0,0.0,0.0,37.0,0.0,0.0,61.0,0.0
3,1000,79.0,31.863636,39.371346,22.0,3.855,36.818182,9.2,23.828877,-0.881882,12.0,274.985885,1.9,19.6,0.96,44.0,86.03802,0.4,158.0,98.0,128.625,83.454545,2.0,3.966667,62.818182,3.463333,97.562013,98.818182,0.608505,106.167733,27.3,86.363636,1.197754,5.817339,141.909091,7.3,11.0,0.0,0.0,0.39,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,100.5,30.0,0.0,0.6,14.0,6.98,0.0,5.0,0.0,0.0,0.0,41.0,0.0,0.0,67.0,0.0
4,10000,76.0,31.424716,28.55,22.0,2.395573,36.75,10.7,25.5,1.5,12.090909,274.985885,3.609838,7.75,1.0,44.5,86.03802,0.5,135.0,98.25,121.75,69.090909,1.4,3.9,48.227273,7.70575,97.562013,98.545455,0.608505,103.5,30.3,77.090909,1.197754,5.817339,123.0,7.39,0.0,5.7,0.0,0.0,1.0,0.6,1.0,3.0,8.0,0.0,0.0,0.5,0.0,6.0,0.0,0.2,8.0,1.0,31.5,30.0,0.0,0.4,18.0,0.0,0.0,3.0,0.0,1.0,2.0,26.0,0.0,0.0,57.0,0.07
5,10002,73.0,19.0,31.3,18.0,3.005,37.0,10.4,23.828877,-0.881882,19.625,161.0,3.0,10.3,0.98,40.5,41.0,0.8,83.0,92.402792,127.166667,69.818182,2.1,4.475,48.5,3.12,38.0,99.181818,0.608505,107.0,30.3,67.090909,0.8,5.817339,132.090909,7.375,0.0,0.0,0.0,2.45,2.0,0.0,0.0,0.0,6.0,0.0,0.0,0.0,0.0,11.0,0.0,0.4,0.0,0.0,24.0,32.0,0.0,0.45,18.5,4.02,0.0,4.0,0.0,2.0,0.0,19.0,0.0,0.0,65.0,0.09
6,10006,51.0,34.020707,37.275808,21.317991,2.395311,37.5,10.949367,23.963837,-0.580284,18.888889,268.857981,3.661243,11.297389,1.438611,41.414587,176.596782,0.538963,212.857003,94.235683,200.5,70.555556,1.987571,4.112962,48.714286,7.742205,98.531877,96.666667,1.586346,105.757518,32.484703,82.0,1.816378,7.683001,117.888889,7.374335,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,10.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,45.0,28.0,0.0,0.0,16.0,0.0,0.0,8.0,0.0,0.0,0.0,12.0,0.0,0.0,59.0,0.0
7,10007,60.0,32.960816,39.756291,23.358213,2.37691,38.0,10.672379,24.035953,-0.727915,21.909091,270.562388,2.4,11.619513,1.50062,41.495807,146.01825,1.069466,209.292544,93.827695,87.0,108.181818,1.6,3.6,88.363636,7.524813,90.714723,94.909091,1.046905,105.700469,31.858225,79.909091,1.435074,0.08,139.363636,7.372458,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,34.0,0.0,0.0,37.0,0.0,0.0,4.0,0.0,0.0,0.0,34.0,0.0,0.0,46.0,0.0
8,10009,69.0,32.960816,86.05,15.0,2.37691,37.25,12.2,21.0,-0.727915,22.5,270.562388,3.1,8.7,0.5,41.495807,17.0,1.069466,182.0,93.827695,109.0,65.909091,1.2,4.1,65.227273,8.1,64.0,97.090909,1.046905,104.0,34.0,97.727273,0.8,6.996616,90.909091,7.372458,0.0,0.3,0.0,0.0,1.0,0.0,0.0,0.0,10.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,27.0,0.0,0.0,13.5,0.0,0.0,5.0,0.0,0.0,0.0,23.0,0.0,0.0,28.0,0.0
9,1001,36.0,35.270246,31.2,10.0,1.8,37.666667,10.4,31.0,7.5,13.363636,282.406566,2.1,8.7,0.5,40.25,179.917422,0.416667,205.0,95.156566,98.666667,85.909091,2.0,3.1,69.181818,8.5,101.242051,100.0,1.97375,105.0,29.8,106.727273,2.196677,3.712206,113.090909,7.515,0.0,0.0,0.0,0.8,1.0,0.0,0.0,5.0,10.0,0.0,0.0,0.0,0.0,27.0,0.0,0.1,0.0,0.0,16.0,29.0,0.0,0.4,21.0,0.0,0.0,0.0,0.0,0.0,1.5,11.0,0.0,0.0,37.0,0.26


## Mean value imputation

In [6]:
def mean_impute_df(df):
    """Return a copy of ``df`` in which each missing value is replaced by its column mean.

    Args:
        df (pd.DataFrame): frame to impute; it is not modified.

    Returns:
        pd.DataFrame: the imputed frame. A column without any observed value
        has a NaN mean and therefore stays NaN.
    """
    column_means = df.mean()
    return df.fillna(value=column_means)

# Fill whatever is still missing in either split with that split's column means.
X_df, X_test_df = map(mean_impute_df, (X_df, X_test_df))

display(X_df.head(30))

Unnamed: 0,pid,Age,EtCO2,PTT,BUN,Lactate,Temp,Hgb,HCO3,BaseExcess,RRate,Fibrinogen,Phosphate,WBC,Creatinine,PaCO2,AST,FiO2,Platelets,SaO2,Glucose,ABPm,Magnesium,Potassium,ABPd,Calcium,Alkalinephos,SpO2,Bilirubin_direct,Chloride,Hct,Heartrate,Bilirubin_total,TroponinI,ABPs,pH,EtCO2_Diff,PTT_Diff,BUN_Diff,Lactate_Diff,Temp_Diff,Hgb_Diff,HCO3_Diff,BaseExcess_Diff,RRate_Diff,Fibrinogen_Diff,Phosphate_Diff,WBC_Diff,Creatinine_Diff,PaCO2_Diff,AST_Diff,FiO2_Diff,Platelets_Diff,SaO2_Diff,Glucose_Diff,ABPm_Diff,Magnesium_Diff,Potassium_Diff,ABPd_Diff,Calcium_Diff,Alkalinephos_Diff,SpO2_Diff,Bilirubin_direct_Diff,Chloride_Diff,Hct_Diff,Heartrate_Diff,Bilirubin_total_Diff,TroponinI_Diff,ABPs_Diff,pH_Diff
0,1,34.0,35.270246,36.296005,12.0,2.267844,36.75,8.566667,25.333333,-0.666667,17.0,282.406566,4.6,5.233333,0.5,43.333333,179.917422,0.425,143.0,95.156566,120.0,68.333333,1.8,4.0,50.25,7.6,101.242051,100.0,1.97375,112.0,23.2,77.083333,2.196677,3.712206,114.5,7.37,0.0,0.0,0.0,0.0,2.0,0.2,2.0,2.0,6.0,0.0,0.0,1.6,0.0,3.0,0.0,0.1,0.0,0.0,0.0,31.0,0.9,0.3,23.5,0.0,0.0,0.0,0.0,3.0,2.4,41.0,0.0,0.0,42.0,0.08
1,10,71.0,31.424716,27.8,12.0,2.395573,36.0,14.6,23.828877,-0.881882,18.090909,274.985885,2.5,11.5,0.82,40.970572,20.0,0.552612,207.0,92.402792,152.0,101.727273,1.5,3.2,83.272727,8.6,68.0,98.0,0.608505,106.167733,42.1,78.818182,1.3,0.01,132.909091,7.375165,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,37.0,0.0,0.0,35.0,0.0,0.0,3.0,0.0,0.0,0.0,19.0,0.0,0.0,39.0,0.0
2,100,68.0,32.960816,20.9,21.0,2.37691,36.25,12.5,27.0,-0.727915,14.833333,270.562388,3.5,12.5,1.1,41.495807,146.01825,1.069466,204.0,93.827695,243.0,81.833333,1.7,3.6,62.833333,9.0,90.714723,96.5,1.046905,101.0,36.8,109.083333,1.435074,6.996616,117.0,7.372458,0.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,32.0,0.0,0.0,20.0,0.0,0.0,8.0,0.0,0.0,0.0,37.0,0.0,0.0,61.0,0.0
3,1000,79.0,31.863636,39.371346,22.0,3.855,36.818182,9.2,23.828877,-0.881882,12.0,274.985885,1.9,19.6,0.96,44.0,86.03802,0.4,158.0,98.0,128.625,83.454545,2.0,3.966667,62.818182,3.463333,97.562013,98.818182,0.608505,106.167733,27.3,86.363636,1.197754,5.817339,141.909091,7.3,11.0,0.0,0.0,0.39,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,100.5,30.0,0.0,0.6,14.0,6.98,0.0,5.0,0.0,0.0,0.0,41.0,0.0,0.0,67.0,0.0
4,10000,76.0,31.424716,28.55,22.0,2.395573,36.75,10.7,25.5,1.5,12.090909,274.985885,3.609838,7.75,1.0,44.5,86.03802,0.5,135.0,98.25,121.75,69.090909,1.4,3.9,48.227273,7.70575,97.562013,98.545455,0.608505,103.5,30.3,77.090909,1.197754,5.817339,123.0,7.39,0.0,5.7,0.0,0.0,1.0,0.6,1.0,3.0,8.0,0.0,0.0,0.5,0.0,6.0,0.0,0.2,8.0,1.0,31.5,30.0,0.0,0.4,18.0,0.0,0.0,3.0,0.0,1.0,2.0,26.0,0.0,0.0,57.0,0.07
5,10002,73.0,19.0,31.3,18.0,3.005,37.0,10.4,23.828877,-0.881882,19.625,161.0,3.0,10.3,0.98,40.5,41.0,0.8,83.0,92.402792,127.166667,69.818182,2.1,4.475,48.5,3.12,38.0,99.181818,0.608505,107.0,30.3,67.090909,0.8,5.817339,132.090909,7.375,0.0,0.0,0.0,2.45,2.0,0.0,0.0,0.0,6.0,0.0,0.0,0.0,0.0,11.0,0.0,0.4,0.0,0.0,24.0,32.0,0.0,0.45,18.5,4.02,0.0,4.0,0.0,2.0,0.0,19.0,0.0,0.0,65.0,0.09
6,10006,51.0,34.020707,37.275808,21.317991,2.395311,37.5,10.949367,23.963837,-0.580284,18.888889,268.857981,3.661243,11.297389,1.438611,41.414587,176.596782,0.538963,212.857003,94.235683,200.5,70.555556,1.987571,4.112962,48.714286,7.742205,98.531877,96.666667,1.586346,105.757518,32.484703,82.0,1.816378,7.683001,117.888889,7.374335,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,10.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,45.0,28.0,0.0,0.0,16.0,0.0,0.0,8.0,0.0,0.0,0.0,12.0,0.0,0.0,59.0,0.0
7,10007,60.0,32.960816,39.756291,23.358213,2.37691,38.0,10.672379,24.035953,-0.727915,21.909091,270.562388,2.4,11.619513,1.50062,41.495807,146.01825,1.069466,209.292544,93.827695,87.0,108.181818,1.6,3.6,88.363636,7.524813,90.714723,94.909091,1.046905,105.700469,31.858225,79.909091,1.435074,0.08,139.363636,7.372458,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,34.0,0.0,0.0,37.0,0.0,0.0,4.0,0.0,0.0,0.0,34.0,0.0,0.0,46.0,0.0
8,10009,69.0,32.960816,86.05,15.0,2.37691,37.25,12.2,21.0,-0.727915,22.5,270.562388,3.1,8.7,0.5,41.495807,17.0,1.069466,182.0,93.827695,109.0,65.909091,1.2,4.1,65.227273,8.1,64.0,97.090909,1.046905,104.0,34.0,97.727273,0.8,6.996616,90.909091,7.372458,0.0,0.3,0.0,0.0,1.0,0.0,0.0,0.0,10.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,27.0,0.0,0.0,13.5,0.0,0.0,5.0,0.0,0.0,0.0,23.0,0.0,0.0,28.0,0.0
9,1001,36.0,35.270246,31.2,10.0,1.8,37.666667,10.4,31.0,7.5,13.363636,282.406566,2.1,8.7,0.5,40.25,179.917422,0.416667,205.0,95.156566,98.666667,85.909091,2.0,3.1,69.181818,8.5,101.242051,100.0,1.97375,105.0,29.8,106.727273,2.196677,3.712206,113.090909,7.515,0.0,0.0,0.0,0.8,1.0,0.0,0.0,5.0,10.0,0.0,0.0,0.0,0.0,27.0,0.0,0.1,0.0,0.0,16.0,29.0,0.0,0.4,21.0,0.0,0.0,0.0,0.0,0.0,1.5,11.0,0.0,0.0,37.0,0.26


## Scaling / normalization

In [7]:
from sklearn import preprocessing

# Scale the dataframe
def scale_df(df, scaler=None):
    """Min-max scale every column of ``df`` except 'pid', in place.

    Args:
        df (pd.DataFrame): frame to scale; its non-'pid' columns are overwritten.
        scaler: an already fitted scaler to apply. When omitted, a new
            ``MinMaxScaler`` is fitted on ``df`` itself.

    Returns:
        The scaler that was applied, so the same transformation can be reused
        on another frame.
    """
    feature_mask = df.columns != 'pid'
    if scaler is None:
        scaler = preprocessing.MinMaxScaler()
        scaler.fit(df.loc[:, feature_mask])
    df.loc[:, feature_mask] = scaler.transform(df.loc[:, feature_mask])
    return scaler

feature_scaler = scale_df(X_df)
display(X_df.head(30))

# Apply the scaler fitted on the training features to the test features.
# Fitting a second scaler on the test set would map the same raw value to a
# different scaled value than the one the model was trained on.
scale_df(X_test_df, feature_scaler);

Unnamed: 0,pid,Age,EtCO2,PTT,BUN,Lactate,Temp,Hgb,HCO3,BaseExcess,RRate,Fibrinogen,Phosphate,WBC,Creatinine,PaCO2,AST,FiO2,Platelets,SaO2,Glucose,ABPm,Magnesium,Potassium,ABPd,Calcium,Alkalinephos,SpO2,Bilirubin_direct,Chloride,Hct,Heartrate,Bilirubin_total,TroponinI,ABPs,pH,EtCO2_Diff,PTT_Diff,BUN_Diff,Lactate_Diff,Temp_Diff,Hgb_Diff,HCO3_Diff,BaseExcess_Diff,RRate_Diff,Fibrinogen_Diff,Phosphate_Diff,WBC_Diff,Creatinine_Diff,PaCO2_Diff,AST_Diff,FiO2_Diff,Platelets_Diff,SaO2_Diff,Glucose_Diff,ABPm_Diff,Magnesium_Diff,Potassium_Diff,ABPd_Diff,Calcium_Diff,Alkalinephos_Diff,SpO2_Diff,Bilirubin_direct_Diff,Chloride_Diff,Hct_Diff,Heartrate_Diff,Bilirubin_total_Diff,TroponinI_Diff,ABPs_Diff,pH_Diff
0,1,0.223529,0.280781,0.083209,0.041199,0.083976,0.489691,0.253268,0.43295,0.477124,0.167539,0.237662,0.264033,0.013268,0.009569,0.383142,0.018023,0.000425,0.059965,0.925486,0.130575,0.267541,0.203125,0.232955,0.164146,0.336735,0.023356,1.0,0.092673,0.565445,0.242224,0.369278,0.045187,0.008414,0.383223,0.615063,0.0,0.0,0.0,0.0,0.125,0.018349,0.068966,0.018519,0.086957,0.0,0.0,0.013104,0.0,0.045455,0.0,2.5e-05,0.0,0.0,0.0,0.133621,0.101124,0.048309,0.095142,0.0,0.0,0.0,0.0,0.078947,0.07717,0.325397,0.0,0.0,0.237288,0.133333
1,10,0.658824,0.238052,0.046761,0.041199,0.089427,0.412371,0.54902,0.398365,0.472904,0.178962,0.230534,0.133056,0.029465,0.017225,0.355984,0.001546,0.000552,0.087575,0.88312,0.171946,0.501481,0.15625,0.142045,0.40068,0.387755,0.014656,0.972603,0.028245,0.473839,0.598492,0.382,0.025862,0.0,0.494322,0.621546,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.115942,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.159483,0.0,0.0,0.1417,0.0,0.0,0.038462,0.0,0.0,0.0,0.150794,0.0,0.0,0.220339,0.0
2,100,0.623529,0.25512,0.01716,0.074906,0.088631,0.438144,0.446078,0.471264,0.475923,0.144852,0.226285,0.195426,0.03205,0.023923,0.362021,0.01453,0.001069,0.086281,0.905041,0.289593,0.362115,0.1875,0.1875,0.254278,0.408163,0.020601,0.952055,0.048934,0.39267,0.498586,0.603944,0.028773,0.015879,0.39831,0.618148,0.0,0.0,0.0,0.0,0.1875,0.0,0.0,0.0,0.072464,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.137931,0.0,0.0,0.080972,0.0,0.0,0.102564,0.0,0.0,0.0,0.293651,0.0,0.0,0.344633,0.0
3,1000,0.752941,0.242929,0.096402,0.078652,0.151707,0.49672,0.284314,0.398365,0.472904,0.115183,0.230534,0.095634,0.050401,0.020574,0.390805,0.00835,0.0004,0.066437,0.969231,0.141726,0.373472,0.234375,0.229167,0.254169,0.12568,0.022393,0.983811,0.028245,0.473839,0.31951,0.437333,0.023658,0.013199,0.548637,0.527197,0.282051,0.0,0.0,0.01413,0.125,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.120504,0.12931,0.0,0.096618,0.05668,0.375875,0.0,0.064103,0.0,0.0,0.0,0.325397,0.0,0.0,0.378531,0.0
4,10000,0.717647,0.238052,0.049979,0.078652,0.089427,0.489691,0.357843,0.436782,0.519608,0.116135,0.230534,0.202277,0.019773,0.021531,0.396552,0.00835,0.0005,0.056514,0.973077,0.132838,0.272849,0.140625,0.221591,0.149658,0.34213,0.022393,0.980075,0.028245,0.431937,0.37606,0.369333,0.023658,0.013199,0.43452,0.640167,0.0,0.024989,0.0,0.0,0.0625,0.055046,0.034483,0.027778,0.115942,0.0,0.0,0.004095,0.0,0.090909,0.0,5e-05,0.015595,0.015385,0.03777,0.12931,0.0,0.064412,0.072874,0.0,0.0,0.038462,0.0,0.026316,0.064309,0.206349,0.0,0.0,0.322034,0.116667
5,10002,0.682353,0.1,0.061776,0.06367,0.115434,0.515464,0.343137,0.398365,0.472904,0.195026,0.121037,0.164241,0.026363,0.021053,0.350575,0.003709,0.0008,0.034081,0.88312,0.139841,0.277943,0.25,0.286932,0.151612,0.108163,0.006805,0.988792,0.028245,0.486911,0.37606,0.296,0.015086,0.013199,0.489384,0.621339,0.0,0.0,0.0,0.088768,0.125,0.0,0.0,0.0,0.086957,0.0,0.0,0.0,0.0,0.166667,0.0,0.0001,0.0,0.0,0.028777,0.137931,0.0,0.072464,0.074899,0.216478,0.0,0.051282,0.0,0.052632,0.0,0.150794,0.0,0.0,0.367232,0.15
6,10006,0.423529,0.266897,0.087412,0.076097,0.089416,0.56701,0.370067,0.401468,0.478818,0.187318,0.224647,0.205483,0.028941,0.032024,0.361087,0.017681,0.000539,0.090102,0.911318,0.234648,0.283109,0.232433,0.245791,0.153146,0.34399,0.022646,0.954338,0.074391,0.467396,0.417242,0.405333,0.036991,0.017439,0.403675,0.620504,0.0,0.0,0.0,0.0,0.0625,0.0,0.0,0.0,0.144928,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.053957,0.12069,0.0,0.0,0.064777,0.0,0.0,0.102564,0.0,0.0,0.0,0.095238,0.0,0.0,0.333333,0.0
7,10007,0.529412,0.25512,0.098054,0.083739,0.088631,0.618557,0.356489,0.403125,0.475923,0.218943,0.226285,0.126819,0.029774,0.033508,0.362021,0.01453,0.001069,0.088565,0.905041,0.087912,0.546698,0.171875,0.1875,0.437145,0.332899,0.020601,0.930262,0.048934,0.4665,0.405433,0.39,0.028773,0.000159,0.533275,0.618148,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.057971,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.146552,0.0,0.0,0.149798,0.0,0.0,0.051282,0.0,0.0,0.0,0.269841,0.0,0.0,0.259887,0.0
8,10009,0.635294,0.25512,0.296654,0.052434,0.088631,0.541237,0.431373,0.333333,0.475923,0.225131,0.226285,0.170478,0.022228,0.009569,0.362021,0.001236,0.001069,0.07679,0.905041,0.116354,0.250559,0.109375,0.244318,0.271425,0.362245,0.013609,0.960149,0.048934,0.439791,0.445806,0.520667,0.015086,0.015879,0.240851,0.618148,0.0,0.001315,0.0,0.0,0.0625,0.0,0.0,0.0,0.144928,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.116379,0.0,0.0,0.054656,0.0,0.0,0.064103,0.0,0.0,0.0,0.18254,0.0,0.0,0.158192,0.0
9,1001,0.247059,0.280781,0.061347,0.033708,0.064011,0.584192,0.343137,0.563218,0.637255,0.129462,0.237662,0.108108,0.022228,0.009569,0.347701,0.018023,0.000417,0.086713,0.925486,0.102995,0.390667,0.234375,0.130682,0.29975,0.382653,0.023356,1.0,0.092673,0.455497,0.366635,0.586667,0.045187,0.008414,0.374719,0.797071,0.0,0.0,0.0,0.028986,0.0625,0.0,0.0,0.046296,0.144928,0.0,0.0,0.0,0.0,0.409091,0.0,2.5e-05,0.0,0.0,0.019185,0.125,0.0,0.064412,0.08502,0.0,0.0,0.0,0.0,0.0,0.048232,0.087302,0.0,0.0,0.20904,0.433333


# SVM Training

In [8]:
# Disabled subsampling to the first 2000 rows (presumably for quicker experiments).
#X_df = X_df.iloc[0:2000, :]
#train_labels_df = train_labels_df.iloc[0:2000, :]

# Prepare train set
# Features and labels are paired by row position, so the first column of both
# frames (the pid) must match row for row.
feature_pids = X_df.iloc[:, 0]
label_pids = train_labels_df.iloc[:, 0]
assert feature_pids.equals(label_pids)

# Everything after the first column becomes the feature / target matrix.
X, y = (frame.iloc[:, 1:].to_numpy() for frame in (X_df, train_labels_df))

assert len(X) == len(y)

In [None]:
from sklearn.multioutput import MultiOutputRegressor
from sklearn.svm import SVR

# Wrap a single SVR configuration so that it is fitted against every column of y.
svr = SVR(C=100, verbose=True)
model = MultiOutputRegressor(estimator=svr)

model.fit(X, y)

[LibSVM][LibSVM][LibSVM][LibSVM]

# Predictions

In [None]:
def predict(df):
    """Predict every label column for the patients in ``df``.

    Relies on two notebook-level names: the fitted ``model`` and
    ``train_labels_df``, of which only the column names are used.

    Args:
        df (pd.DataFrame): feature table whose first column is 'pid' and whose
            remaining columns are the features the model was fitted on.

    Returns:
        pd.DataFrame: one row per row of ``df`` (same index) with the columns
        of ``train_labels_df``; 'pid' is copied from ``df`` and every label
        column is numeric.
    """
    # Create X_predict by removing 'pid' column
    X_predict = df.iloc[:, 1:].to_numpy()

    # Predict
    predictions = model.predict(X_predict)
    # Squash every output except the last four through a sigmoid so it lies in
    # (0, 1). In train_labels.csv the last four label columns are the vital-sign
    # targets (RRate, ABPm, SpO2, Heartrate), which keep their raw scale.
    predictions[:, :-4] = np.divide(1, 1 + np.exp(-predictions[:, :-4]))

    # Build the frame directly from the prediction array so the label columns
    # get a float dtype instead of depending on how assignment into an empty
    # object-dtype frame behaves.
    label_columns = [col for col in train_labels_df.columns if col != 'pid']
    predict_labels_df = pd.DataFrame(predictions, index=df.index, columns=label_columns)
    predict_labels_df['pid'] = df['pid']

    # Same column order as the label file.
    return predict_labels_df[list(train_labels_df.columns)]

## Train set

In [None]:
# Predict on the training features themselves to compare against the known labels.
prediction_labels_df = predict(X_df)

print(prediction_labels_df.shape)
display(prediction_labels_df.head(15), train_labels_df.head(15))

In [None]:
# Write the training-set predictions without the index column, floats with 3 decimals.
prediction_labels_df.to_csv('data/prediction_train.csv', float_format='%.3f', index=False)

In [None]:
# Write the true training labels in the same CSV format as the predictions.
train_labels_df.to_csv('data/prediction_gold.csv', float_format='%.3f', index=False)

## Test set

In [None]:
# Predict the labels of the test patients.
prediction_test_labels_df = predict(df=X_test_df)

print(prediction_test_labels_df.shape)
display(prediction_test_labels_df.iloc[:15])

In [None]:
# Write the test-set predictions without the index column, floats with 3 decimals.
prediction_test_labels_df.to_csv('data/prediction.csv', float_format='%.3f', index=False)