In [1]:
import numpy as np
import pandas as pd
import sys
import os
import pickle 

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.metrics import log_loss
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import average_precision_score
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import label_binarize
from sklearn.ensemble import RandomForestClassifier
import scipy.stats as ss

In [2]:
sys.path.append('../utils')
from simple_impute import simple_imputer

# Task Specifics

In [3]:
INTERVENTION = 'vent'
RANDOM = 0
MAX_LEN = 240
SLICE_SIZE = 6
GAP_TIME = 6
PREDICTION_WINDOW = 4
OUTCOME_TYPE = 'all'
NUM_CLASSES = 4

In [4]:
CHUNK_KEY = {'ONSET': 0, 'CONTROL': 1, 'ON_INTERVENTION': 2, 'WEAN': 3}

# Load Data

In [5]:
DATAFILE = 'D:/data/MIMIC_Extract/samples.h5'

In [6]:
X = pd.read_hdf(DATAFILE,'vitalslabs')
Y = pd.read_hdf(DATAFILE,'interventions')
static = pd.read_hdf(DATAFILE,'patients')

# save data by h5py

In [7]:
static.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,gender,ethnicity,age,insurance,admittime,diagnosis_at_admission,dischtime,discharge_location,fullcode_first,dnr_first,...,outtime,los_icu,admission_type,first_careunit,mort_icu,mort_hosp,hospital_expire_flag,hospstay_seq,readmission_30,max_hours
subject_id,hadm_id,icustay_id,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
3,145834,211552,M,WHITE,76.526792,Medicare,2101-10-20 19:08:00,HYPOTENSION,2101-10-31 13:58:00,SNF,1.0,0.0,...,2101-10-26 20:43:09,6.06456,EMERGENCY,MICU,0,0,0,1,0,145
4,185777,294638,F,WHITE,47.845047,Private,2191-03-16 00:28:00,"FEVER,DEHYDRATION,FAILURE TO THRIVE",2191-03-23 18:41:00,HOME WITH HOME IV PROVIDR,1.0,0.0,...,2191-03-17 16:46:31,1.678472,EMERGENCY,MICU,0,0,0,1,0,40
6,107064,228232,F,WHITE,65.942297,Medicare,2175-05-30 07:15:00,CHRONIC RENAL FAILURE/SDA,2175-06-15 16:00:00,HOME HEALTH CARE,1.0,0.0,...,2175-06-03 13:39:54,3.672917,ELECTIVE,SICU,0,0,0,1,0,88
9,150750,220597,M,UNKNOWN/NOT SPECIFIED,41.790228,Medicaid,2149-11-09 13:06:00,HEMORRHAGIC CVA,2149-11-14 10:15:00,DEAD/EXPIRED,1.0,0.0,...,2149-11-14 20:52:14,5.323056,EMERGENCY,MICU,1,1,1,1,0,127
11,194540,229441,F,WHITE,50.148295,Private,2178-04-16 06:18:00,BRAIN MASS,2178-05-11 19:00:00,HOME HEALTH CARE,1.0,0.0,...,2178-04-17 20:21:05,1.58441,EMERGENCY,SICU,0,0,0,1,0,38


In [8]:
X.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin ascites,albumin ascites,albumin pleural,...,white blood cell count,white blood cell count urine,white blood cell count urine,white blood cell count urine,ph,ph,ph,ph urine,ph urine,ph urine
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Aggregation Function,count,mean,std,count,mean,std,count,mean,std,count,...,std,count,mean,std,count,mean,std,count,mean,std
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
3,145834,211552,0,2.0,25.0,0.0,2.0,1.8,0.0,0.0,,,0.0,...,4.012837,0.0,,,9.0,7.4,0.147733,1.0,5.0,
3,145834,211552,1,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
3,145834,211552,2,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,3.0,7.26,0.0,0.0,,
3,145834,211552,3,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
3,145834,211552,4,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,


In [10]:
X.shape

(2235, 312)

In [9]:
Y.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,vent,vaso,adenosine,dobutamine,dopamine,epinephrine,isuprel,milrinone,norepinephrine,phenylephrine,vasopressin,colloid_bolus,crystalloid_bolus,nivdurations
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
3,145834,211552,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
3,145834,211552,1,1,1,0,0,1,0,0,0,0,1,0,0,0,0
3,145834,211552,2,1,1,0,0,1,0,0,0,0,1,0,0,0,0
3,145834,211552,3,1,1,0,0,0,0,0,0,0,1,0,0,0,0
3,145834,211552,4,1,1,0,0,0,0,0,0,1,1,0,0,0,0


In [12]:
Y.shape

(12714, 14)

In [13]:
patient_num = 3

In [15]:
idx = pd.IndexSlice
Y = Y.loc[idx[Y.index.levels[0][:patient_num]]]

In [16]:
Y = Y.loc[:, (Y != 0).any(axis = 0)]
X = X.loc[idx[X.index.levels[0][:patient_num]]]

In [17]:
Y

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,vent,vaso,dopamine,norepinephrine,phenylephrine,crystalloid_bolus,nivdurations
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
3,145834,211552,0,1,0,0,0,0,0,0
3,145834,211552,1,1,1,1,0,1,0,0
3,145834,211552,2,1,1,1,0,1,0,0
3,145834,211552,3,1,1,0,0,1,0,0
3,145834,211552,4,1,1,0,1,1,0,0
3,145834,211552,5,1,1,0,1,1,0,0
3,145834,211552,6,1,1,0,1,1,0,0
3,145834,211552,7,1,1,0,1,1,0,0
3,145834,211552,8,1,1,0,1,1,0,0
3,145834,211552,9,1,1,0,1,1,0,0


In [18]:
X

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin ascites,albumin ascites,albumin pleural,...,white blood cell count,white blood cell count urine,white blood cell count urine,white blood cell count urine,ph,ph,ph,ph urine,ph urine,ph urine
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Aggregation Function,count,mean,std,count,mean,std,count,mean,std,count,...,std,count,mean,std,count,mean,std,count,mean,std
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
3,145834,211552,0,2.0,25.0,0.0,2.0,1.8,0.0,0.0,,,0.0,...,4.012837,0.0,,,9.0,7.4000,0.147733,1.0,5.0,
3,145834,211552,1,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
3,145834,211552,2,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,3.0,7.2600,0.000000,0.0,,
3,145834,211552,3,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
3,145834,211552,4,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
3,145834,211552,5,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,3.0,7.2900,0.000000,0.0,,
3,145834,211552,6,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,3.0,7.3000,0.000000,0.0,,
3,145834,211552,7,0.0,,,0.0,,,0.0,,,0.0,...,0.000000,0.0,,,3.0,7.3300,0.000000,0.0,,
3,145834,211552,8,0.0,,,0.0,,,0.0,,,0.0,...,,1.0,35.0,,0.0,,,1.0,5.0,
3,145834,211552,9,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,


In [19]:
# 删除全部为0的列
Y = Y.loc[:, (Y != 0).any(axis = 0)]
X = X.loc[idx[X.index.levels[0][:patient_num]]]

In [20]:
X

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin ascites,albumin ascites,albumin pleural,...,white blood cell count,white blood cell count urine,white blood cell count urine,white blood cell count urine,ph,ph,ph,ph urine,ph urine,ph urine
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Aggregation Function,count,mean,std,count,mean,std,count,mean,std,count,...,std,count,mean,std,count,mean,std,count,mean,std
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
3,145834,211552,0,2.0,25.0,0.0,2.0,1.8,0.0,0.0,,,0.0,...,4.012837,0.0,,,9.0,7.4000,0.147733,1.0,5.0,
3,145834,211552,1,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
3,145834,211552,2,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,3.0,7.2600,0.000000,0.0,,
3,145834,211552,3,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
3,145834,211552,4,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
3,145834,211552,5,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,3.0,7.2900,0.000000,0.0,,
3,145834,211552,6,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,3.0,7.3000,0.000000,0.0,,
3,145834,211552,7,0.0,,,0.0,,,0.0,,,0.0,...,0.000000,0.0,,,3.0,7.3300,0.000000,0.0,,
3,145834,211552,8,0.0,,,0.0,,,0.0,,,0.0,...,,1.0,35.0,,0.0,,,1.0,5.0,
3,145834,211552,9,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,


In [21]:
Y

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,vent,vaso,dopamine,norepinephrine,phenylephrine,crystalloid_bolus,nivdurations
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
3,145834,211552,0,1,0,0,0,0,0,0
3,145834,211552,1,1,1,1,0,1,0,0
3,145834,211552,2,1,1,1,0,1,0,0
3,145834,211552,3,1,1,0,0,1,0,0
3,145834,211552,4,1,1,0,1,1,0,0
3,145834,211552,5,1,1,0,1,1,0,0
3,145834,211552,6,1,1,0,1,1,0,0
3,145834,211552,7,1,1,0,1,1,0,0
3,145834,211552,8,1,1,0,1,1,0,0
3,145834,211552,9,1,1,0,1,1,0,0


# simple_imputer

In [24]:
def simple_imputer(df,train_subj):
    idx = pd.IndexSlice
    df = df.copy()
    
    df_out = df.loc[:, idx[:, ['mean', 'count']]]
    icustay_means = df_out.loc[:, idx[:, 'mean']].groupby(ID_COLS).mean()
    global_means = df_out.loc[idx[train_subj,:], idx[:, 'mean']].mean(axis=0)
    
    df_out.loc[:,idx[:,'mean']] = df_out.loc[:,idx[:,'mean']].groupby(ID_COLS).fillna(
        method='ffill'
    ).groupby(ID_COLS).fillna(icustay_means).fillna(global_means)
    
    df_out.loc[:, idx[:, 'count']] = (df.loc[:, idx[:, 'count']] > 0).astype(float)
    df_out.rename(columns={'count': 'mask'}, level='Aggregation Function', inplace=True)
    
    is_absent = (1 - df_out.loc[:, idx[:, 'mask']])
    hours_of_absence = is_absent.cumsum()
    time_since_measured = hours_of_absence - hours_of_absence[is_absent==0].fillna(method='ffill')
    time_since_measured.rename(columns={'mask': 'time_since_measured'}, level='Aggregation Function', inplace=True)

    df_out = pd.concat((df_out, time_since_measured), axis=1)
    df_out.loc[:, idx[:, 'time_since_measured']] = df_out.loc[:, idx[:, 'time_since_measured']].fillna(100)
    
    df_out.sort_index(axis=1, inplace=True)
    return df_out

In [25]:
train_ids, test_ids = train_test_split(static.reset_index(), test_size=0.2, random_state=RANDOM, stratify=static['mort_hosp'])
split_train_ids, val_ids = train_test_split(train_ids, test_size=0.125, random_state=RANDOM, stratify=train_ids['mort_hosp'])

In [27]:
ID_COLS = ['subject_id', 'hadm_id', 'icustay_id']

In [44]:
# Imputation and Standardization of Time Series Features¶
X_clean = simple_imputer(X,train_ids['subject_id'])

In [45]:
X_clean

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin ascites,albumin ascites,albumin pleural,...,venous pvo2,weight,weight,weight,white blood cell count,white blood cell count,white blood cell count,white blood cell count urine,white blood cell count urine,white blood cell count urine
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Aggregation Function,mask,mean,time_since_measured,mask,mean,time_since_measured,mask,mean,time_since_measured,mask,...,time_since_measured,mask,mean,time_since_measured,mask,mean,time_since_measured,mask,mean,time_since_measured
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
3,145834,211552,0,1.0,25.0,0.0,1.0,1.8,0.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,1.0,14.842857,0.0,0.0,35.0,100.0
3,145834,211552,1,0.0,25.0,1.0,0.0,1.8,1.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,0.0,14.842857,1.0,0.0,35.0,100.0
3,145834,211552,2,0.0,25.0,2.0,0.0,1.8,2.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,0.0,14.842857,2.0,0.0,35.0,100.0
3,145834,211552,3,0.0,25.0,3.0,0.0,1.8,3.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,0.0,14.842857,3.0,0.0,35.0,100.0
3,145834,211552,4,0.0,25.0,4.0,0.0,1.8,4.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,0.0,14.842857,4.0,0.0,35.0,100.0
3,145834,211552,5,0.0,25.0,5.0,0.0,1.8,5.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,0.0,14.842857,5.0,0.0,35.0,100.0
3,145834,211552,6,0.0,25.0,6.0,0.0,1.8,6.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,0.0,14.842857,6.0,0.0,35.0,100.0
3,145834,211552,7,0.0,25.0,7.0,0.0,1.8,7.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,1.0,24.400000,0.0,0.0,35.0,100.0
3,145834,211552,8,0.0,25.0,8.0,0.0,1.8,8.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,0.0,24.400000,1.0,1.0,35.0,0.0
3,145834,211552,9,0.0,25.0,9.0,0.0,1.8,9.0,0.0,,100.0,0.0,...,100.0,0.0,107.000000,100.0,0.0,24.400000,2.0,0.0,35.0,1.0


# normalize

In [46]:
def minmax(x):# normalize
    mins = x.min()
    maxes = x.max()
    x_std = (x - mins) / (maxes - mins)
    return x_std

def std_time_since_measurement(x):
    idx = pd.IndexSlice
    x = np.where(x==100, 0, x)
    means = x.mean()
    stds = x.std()
    x_std = (x - means)/stds
    return x_std

In [47]:
idx = pd.IndexSlice
X_std = X_clean.copy()
X_std.loc[:,idx[:,'mean']] = X_std.loc[:,idx[:,'mean']].apply(lambda x: minmax(x))
X_std.loc[:,idx[:,'time_since_measured']] = X_std.loc[:,idx[:,'time_since_measured']].apply(lambda x: std_time_since_measurement(x))

  if sys.path[0] == '':


In [48]:
X_std

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin ascites,albumin ascites,albumin pleural,...,venous pvo2,weight,weight,weight,white blood cell count,white blood cell count,white blood cell count,white blood cell count urine,white blood cell count urine,white blood cell count urine
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Aggregation Function,mask,mean,time_since_measured,mask,mean,time_since_measured,mask,mean,time_since_measured,mask,...,time_since_measured,mask,mean,time_since_measured,mask,mean,time_since_measured,mask,mean,time_since_measured
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
3,145834,211552,0,1.0,1.00,-1.012686,1.0,0.0,-1.187512,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,1.0,0.555482,-1.426496,0.0,,-1.619531
3,145834,211552,1,0.0,1.00,-0.981802,0.0,0.0,-1.164153,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-1.277869,0.0,,-1.619531
3,145834,211552,2,0.0,1.00,-0.950918,0.0,0.0,-1.140793,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-1.129242,0.0,,-1.619531
3,145834,211552,3,0.0,1.00,-0.920034,0.0,0.0,-1.117434,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-0.980615,0.0,,-1.619531
3,145834,211552,4,0.0,1.00,-0.889150,0.0,0.0,-1.094075,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-0.831988,0.0,,-1.619531
3,145834,211552,5,0.0,1.00,-0.858266,0.0,0.0,-1.070716,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-0.683361,0.0,,-1.619531
3,145834,211552,6,0.0,1.00,-0.827381,0.0,0.0,-1.047356,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-0.534734,0.0,,-1.619531
3,145834,211552,7,0.0,1.00,-0.796497,0.0,0.0,-1.023997,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,1.0,1.000000,-1.426496,0.0,,-1.619531
3,145834,211552,8,0.0,1.00,-0.765613,0.0,0.0,-1.000638,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,1.000000,-1.277869,1.0,,-1.619531
3,145834,211552,9,0.0,1.00,-0.734729,0.0,0.0,-0.977279,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,1.000000,-1.129242,0.0,,-1.607003


In [49]:
X_std.columns = X_std.columns.droplevel(-1)

In [51]:
X_std

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin ascites,albumin ascites,albumin pleural,...,venous pvo2,weight,weight,weight,white blood cell count,white blood cell count,white blood cell count,white blood cell count urine,white blood cell count urine,white blood cell count urine
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
3,145834,211552,0,1.0,1.00,-1.012686,1.0,0.0,-1.187512,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,1.0,0.555482,-1.426496,0.0,,-1.619531
3,145834,211552,1,0.0,1.00,-0.981802,0.0,0.0,-1.164153,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-1.277869,0.0,,-1.619531
3,145834,211552,2,0.0,1.00,-0.950918,0.0,0.0,-1.140793,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-1.129242,0.0,,-1.619531
3,145834,211552,3,0.0,1.00,-0.920034,0.0,0.0,-1.117434,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-0.980615,0.0,,-1.619531
3,145834,211552,4,0.0,1.00,-0.889150,0.0,0.0,-1.094075,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-0.831988,0.0,,-1.619531
3,145834,211552,5,0.0,1.00,-0.858266,0.0,0.0,-1.070716,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-0.683361,0.0,,-1.619531
3,145834,211552,6,0.0,1.00,-0.827381,0.0,0.0,-1.047356,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,0.555482,-0.534734,0.0,,-1.619531
3,145834,211552,7,0.0,1.00,-0.796497,0.0,0.0,-1.023997,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,1.0,1.000000,-1.426496,0.0,,-1.619531
3,145834,211552,8,0.0,1.00,-0.765613,0.0,0.0,-1.000638,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,1.000000,-1.277869,1.0,,-1.619531
3,145834,211552,9,0.0,1.00,-0.734729,0.0,0.0,-0.977279,0.0,,,0.0,...,-1.423251,0.0,0.981618,-1.302235,0.0,1.000000,-1.129242,0.0,,-1.607003


In [52]:
def categorize_age(age):
    if age > 10 and age <= 30: 
        cat = 1
    elif age > 30 and age <= 50:
        cat = 2
    elif age > 50 and age <= 70:
        cat = 3
    else: 
        cat = 4
    return cat

def categorize_ethnicity(ethnicity):
    if 'AMERICAN INDIAN' in ethnicity:
        ethnicity = 'AMERICAN INDIAN'
    elif 'ASIAN' in ethnicity:
        ethnicity = 'ASIAN'
    elif 'WHITE' in ethnicity:
        ethnicity = 'WHITE'
    elif 'HISPANIC' in ethnicity:
        ethnicity = 'HISPANIC/LATINO'
    elif 'BLACK' in ethnicity:
        ethnicity = 'BLACK'
    else: 
        ethnicity = 'OTHER'
    return ethnicity

In [53]:
static_to_keep = static[['gender', 'age', 'ethnicity', 'first_careunit', 'intime']]
static_to_keep.loc[:, 'intime'] = static_to_keep['intime'].astype('datetime64').apply(lambda x : x.hour)
static_to_keep.loc[:, 'age'] = static_to_keep['age'].apply(categorize_age)
static_to_keep.loc[:, 'ethnicity'] = static_to_keep['ethnicity'].apply(categorize_ethnicity)
static_to_keep = pd.get_dummies(static_to_keep, columns = ['gender', 'age', 'ethnicity', 'first_careunit'])

In [54]:
static_to_keep

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,intime,gender_F,gender_M,age_1,age_2,age_3,age_4,ethnicity_AMERICAN INDIAN,ethnicity_ASIAN,ethnicity_BLACK,ethnicity_HISPANIC/LATINO,ethnicity_OTHER,ethnicity_WHITE,first_careunit_CCU,first_careunit_CSRU,first_careunit_MICU,first_careunit_SICU,first_careunit_TSICU
subject_id,hadm_id,icustay_id,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
3,145834,211552,19,0,1,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0
4,185777,294638,0,1,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0
6,107064,228232,21,1,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,0
9,150750,220597,13,0,1,0,1,0,0,0,0,0,0,1,0,0,0,1,0,0
11,194540,229441,6,1,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,0
12,112213,232669,2,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0
13,143045,263738,18,1,0,0,1,0,0,0,0,0,0,0,1,1,0,0,0,0
17,194023,277042,16,1,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,0
18,188822,298129,11,0,1,0,0,1,0,0,0,0,0,0,1,1,0,0,0,0
19,109235,273430,16,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,0,1


In [55]:
X_merge = pd.merge(X_std.reset_index(), static_to_keep.reset_index(), on=['subject_id','icustay_id','hadm_id'])

In [56]:
X_merge

Unnamed: 0,subject_id,hadm_id,icustay_id,hours_in,alanine aminotransferase,alanine aminotransferase.1,alanine aminotransferase.2,albumin,albumin.1,albumin.2,...,ethnicity_ASIAN,ethnicity_BLACK,ethnicity_HISPANIC/LATINO,ethnicity_OTHER,ethnicity_WHITE,first_careunit_CCU,first_careunit_CSRU,first_careunit_MICU,first_careunit_SICU,first_careunit_TSICU
0,3,145834,211552,0,1.0,1.00,-1.012686,1.0,0.0,-1.187512,...,0,0,0,0,1,0,0,1,0,0
1,3,145834,211552,1,0.0,1.00,-0.981802,0.0,0.0,-1.164153,...,0,0,0,0,1,0,0,1,0,0
2,3,145834,211552,2,0.0,1.00,-0.950918,0.0,0.0,-1.140793,...,0,0,0,0,1,0,0,1,0,0
3,3,145834,211552,3,0.0,1.00,-0.920034,0.0,0.0,-1.117434,...,0,0,0,0,1,0,0,1,0,0
4,3,145834,211552,4,0.0,1.00,-0.889150,0.0,0.0,-1.094075,...,0,0,0,0,1,0,0,1,0,0
5,3,145834,211552,5,0.0,1.00,-0.858266,0.0,0.0,-1.070716,...,0,0,0,0,1,0,0,1,0,0
6,3,145834,211552,6,0.0,1.00,-0.827381,0.0,0.0,-1.047356,...,0,0,0,0,1,0,0,1,0,0
7,3,145834,211552,7,0.0,1.00,-0.796497,0.0,0.0,-1.023997,...,0,0,0,0,1,0,0,1,0,0
8,3,145834,211552,8,0.0,1.00,-0.765613,0.0,0.0,-1.000638,...,0,0,0,0,1,0,0,1,0,0
9,3,145834,211552,9,0.0,1.00,-0.734729,0.0,0.0,-0.977279,...,0,0,0,0,1,0,0,1,0,0


In [57]:
abs_time = (X_merge['intime'] + X_merge['hours_in'])%24

In [58]:
X_merge.insert(4, 'absolute_time', abs_time)

In [59]:
X_merge.drop('intime', axis=1, inplace=True)

In [60]:
X_merge = X_merge.set_index(['subject_id','icustay_id','hadm_id','hours_in'])

In [61]:
X_merge

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,absolute_time,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin ascites,albumin ascites,...,ethnicity_ASIAN,ethnicity_BLACK,ethnicity_HISPANIC/LATINO,ethnicity_OTHER,ethnicity_WHITE,first_careunit_CCU,first_careunit_CSRU,first_careunit_MICU,first_careunit_SICU,first_careunit_TSICU
subject_id,icustay_id,hadm_id,hours_in,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
3,211552,145834,0,19,1.0,1.00,-1.012686,1.0,0.0,-1.187512,0.0,,,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,1,20,0.0,1.00,-0.981802,0.0,0.0,-1.164153,0.0,,,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,2,21,0.0,1.00,-0.950918,0.0,0.0,-1.140793,0.0,,,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,3,22,0.0,1.00,-0.920034,0.0,0.0,-1.117434,0.0,,,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,4,23,0.0,1.00,-0.889150,0.0,0.0,-1.094075,0.0,,,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,5,0,0.0,1.00,-0.858266,0.0,0.0,-1.070716,0.0,,,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,6,1,0.0,1.00,-0.827381,0.0,0.0,-1.047356,0.0,,,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,7,2,0.0,1.00,-0.796497,0.0,0.0,-1.023997,0.0,,,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,8,3,0.0,1.00,-0.765613,0.0,0.0,-1.000638,0.0,,,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,9,4,0.0,1.00,-0.734729,0.0,0.0,-0.977279,0.0,,,...,0,0,0,0,1,0,0,1,0,0


In [62]:
def create_x_matrix(x):
    zeros = np.zeros((MAX_LEN, x.shape[1]-4))
    x = x.values
    x = x[:(MAX_LEN), 4:]
    zeros[0:x.shape[0], :] = x
    return zeros

def create_y_matrix(y):
    zeros = np.zeros((MAX_LEN, y.shape[1]-4))
    y = y.values
    y = y[:,4:]
    y = y[:MAX_LEN, :]
    zeros[:y.shape[0], :] = y
    return zeros

In [63]:
X_merge = X_merge.dropna(axis=1)

In [64]:
X_merge

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,absolute_time,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin pleural,albumin urine,...,ethnicity_ASIAN,ethnicity_BLACK,ethnicity_HISPANIC/LATINO,ethnicity_OTHER,ethnicity_WHITE,first_careunit_CCU,first_careunit_CSRU,first_careunit_MICU,first_careunit_SICU,first_careunit_TSICU
subject_id,icustay_id,hadm_id,hours_in,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
3,211552,145834,0,19,1.0,1.00,-1.012686,1.0,0.0,-1.187512,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,1,20,0.0,1.00,-0.981802,0.0,0.0,-1.164153,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,2,21,0.0,1.00,-0.950918,0.0,0.0,-1.140793,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,3,22,0.0,1.00,-0.920034,0.0,0.0,-1.117434,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,4,23,0.0,1.00,-0.889150,0.0,0.0,-1.094075,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,5,0,0.0,1.00,-0.858266,0.0,0.0,-1.070716,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,6,1,0.0,1.00,-0.827381,0.0,0.0,-1.047356,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,7,2,0.0,1.00,-0.796497,0.0,0.0,-1.023997,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,8,3,0.0,1.00,-0.765613,0.0,0.0,-1.000638,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0
3,211552,145834,9,4,0.0,1.00,-0.734729,0.0,0.0,-0.977279,0.0,0.0,0.0,...,0,0,0,0,1,0,0,1,0,0


In [65]:
x = np.array(list(X_merge.reset_index().groupby('subject_id').apply(create_x_matrix)))

In [71]:
y = np.array(list(Y.reset_index().groupby('subject_id').apply(create_y_matrix)))

In [66]:
x

array([[[19.  ,  1.  ,  1.  , ...,  1.  ,  0.  ,  0.  ],
        [20.  ,  0.  ,  1.  , ...,  1.  ,  0.  ,  0.  ],
        [21.  ,  0.  ,  1.  , ...,  1.  ,  0.  ,  0.  ],
        ...,
        [ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.  ,  0.  ],
        [ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.  ,  0.  ],
        [ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.  ,  0.  ]],

       [[ 0.  ,  0.  ,  0.95, ...,  1.  ,  0.  ,  0.  ],
        [ 1.  ,  0.  ,  0.95, ...,  1.  ,  0.  ,  0.  ],
        [ 2.  ,  0.  ,  0.95, ...,  1.  ,  0.  ,  0.  ],
        ...,
        [ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.  ,  0.  ],
        [ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.  ,  0.  ],
        [ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.  ,  0.  ]],

       [[21.  ,  1.  ,  0.9 , ...,  0.  ,  1.  ,  0.  ],
        [22.  ,  0.  ,  0.9 , ...,  0.  ,  1.  ,  0.  ],
        [23.  ,  0.  ,  0.9 , ...,  0.  ,  1.  ,  0.  ],
        ...,
        [ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.  ,  0.  ],
        [ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.

In [68]:
x[0][0]

array([ 1.90000000e+01,  1.00000000e+00,  1.00000000e+00, -1.01268629e+00,
        1.00000000e+00,  0.00000000e+00, -1.18751184e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  1.00000000e+00,  3.29113924e-02,
       -1.01268629e+00,  1.00000000e+00,  8.33333333e-01, -1.43949972e+00,
        1.00000000e+00,  1.00000000e+00, -1.01268629e+00,  0.00000000e+00,
        1.00000000e+00, -1.17538923e+00,  1.00000000e+00,  3.33333333e-01,
       -1.43166973e+00,  1.00000000e+00,  4.55696203e-01, -1.01268629e+00,
        1.00000000e+00,  5.42857143e-01, -1.48383086e+00,  1.00000000e+00,
        1.15789548e-01, -1.37795563e+00,  1.00000000e+00,  4.18803419e-01,
       -1.10417442e+00,  0.00000000e+00,  0.00000000e+00,  4.03067916e-01,
       -1.07015113e+00,  0.00000000e+00,  3.27608404e-01, -1.07015113e+00,
        0.00000000e+00,  5.72327039e-01, -1.44425672e+00,  0.00000000e+00,
        2.47064394e-01, -6.51903717e-01,  1.00000000e+00,  8.43750000e-01,
       -1.36385963e+00,  

In [74]:
y[1]

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 0., ..., 0., 0., 1.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [69]:
lengths = np.array(list(X_merge.reset_index().groupby('subject_id').apply(lambda x: x.shape[0])))
lengths

array([146,  41,  89])

In [76]:
RANDOM = 0
MAX_LEN = 240
SLICE_SIZE = 6
GAP_TIME = 0
PREDICTION_WINDOW = 1
OUTCOME_TYPE = 'binary'
NUM_CLASSES = 2
CHUNK_KEY = {'ONSET': 0, 'CONTROL': 1, 'ON_INTERVENTION': 2, 'WEAN': 3}

In [84]:
def make_3d_tensor_slices(X_tensor, Y_tensor, lengths):

    num_patients = X_tensor.shape[0]
    timesteps = X_tensor.shape[1]
    num_features = X_tensor.shape[2]
    num_Y_features = Y_tensor.shape[2]
    # SLICE_SIZE 片大小 6
    X_tensor_new = np.zeros((lengths.sum(), SLICE_SIZE, num_features + num_Y_features))
    Y_tensor_new = np.zeros((lengths.sum(), num_Y_features))
    number_of_1 = 0
    current_row = 0
    # print(num_patients)
    for patient_index in range(num_patients):
        x_patient = X_tensor[patient_index]
        y_patient = Y_tensor[patient_index]
        length = lengths[patient_index]
        for timestep in range(length - PREDICTION_WINDOW - GAP_TIME - SLICE_SIZE):
            x_window = x_patient[timestep:timestep+SLICE_SIZE]
            y_window = y_patient[timestep:timestep+SLICE_SIZE]
            x_window = np.concatenate((x_window, y_window), axis=1)
            result = []
            for i in range(num_Y_features):
                # 隔了 PREDICTION_WINDOW
                result_i = y_patient[timestep+SLICE_SIZE+GAP_TIME:timestep+SLICE_SIZE+GAP_TIME+PREDICTION_WINDOW,i]
                Y_tensor_new[current_row,i] = result_i
            X_tensor_new[current_row] = x_window
            current_row += 1
    X_tensor_new = X_tensor_new[:current_row,:,:]
    Y_tensor_new = Y_tensor_new[:current_row,:]

    return X_tensor_new, Y_tensor_new

In [85]:
x_train, y_train = make_3d_tensor_slices(x, y, lengths)

In [88]:
x_train[0][0]

array([ 1.90000000e+01,  1.00000000e+00,  1.00000000e+00, -1.01268629e+00,
        1.00000000e+00,  0.00000000e+00, -1.18751184e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  1.00000000e+00,  3.29113924e-02,
       -1.01268629e+00,  1.00000000e+00,  8.33333333e-01, -1.43949972e+00,
        1.00000000e+00,  1.00000000e+00, -1.01268629e+00,  0.00000000e+00,
        1.00000000e+00, -1.17538923e+00,  1.00000000e+00,  3.33333333e-01,
       -1.43166973e+00,  1.00000000e+00,  4.55696203e-01, -1.01268629e+00,
        1.00000000e+00,  5.42857143e-01, -1.48383086e+00,  1.00000000e+00,
        1.15789548e-01, -1.37795563e+00,  1.00000000e+00,  4.18803419e-01,
       -1.10417442e+00,  0.00000000e+00,  0.00000000e+00,  4.03067916e-01,
       -1.07015113e+00,  0.00000000e+00,  3.27608404e-01, -1.07015113e+00,
        0.00000000e+00,  5.72327039e-01, -1.44425672e+00,  0.00000000e+00,
        2.47064394e-01, -6.51903717e-01,  1.00000000e+00,  8.43750000e-01,
       -1.36385963e+00,  

In [80]:
y_train

array([[1., 1., 0., ..., 1., 0., 0.],
       [1., 1., 0., ..., 1., 0., 0.],
       [1., 1., 0., ..., 1., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])