In [22]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys, os, pickle, utils, math, tqdm

from datetime import timedelta
#from utils import baseline_SCr

# When the notebook is launched from a directory whose path ends in "code",
# step up one level so the relative ./data paths below resolve.
if os.getcwd().endswith("code"):
    os.chdir('../')

# Locations of the MIMIC-IV 2.2 parquet directories.
hosp = './data/mimic-iv-2.2-parquet/hosp/'
icu = './data/mimic-iv-2.2-parquet/icu/'

# Turn pandas' chained-assignment warning off.
pd.set_option('mode.chained_assignment', None)

In [24]:
# Load the resampled, labelled table (per the file name) and show how many
# distinct stay_id values it contains.
df = pd.read_parquet('./data/resample/resample_label.parquet')
len(df['stay_id'].unique())

73181

In [25]:
# One-column frame listing the column names of df at this point.
columnss = pd.DataFrame(df.columns)

In [3]:
# Order rows by patient, admission, stay and time. The later head(24),
# iloc[-1] and cumcount steps work on row position within each stay.
# NOTE(review): this cell's execution count (3) is lower than that of the cell
# that loads df (24), so in the recorded run the sort may not have been applied
# to the reloaded frame - confirm with Restart & Run All.
df.sort_values(by=['subject_id','hadm_id','stay_id','charttime'],inplace=True)

In [26]:
# Discharged within 24 hours: keep only stays that have more than 24 rows.
rows_per_stay = df.groupby('stay_id')['stay_id'].transform('size')
df = df[rows_per_stay > 24]
len(df.stay_id.unique())

57734

In [27]:
# AKI onset within the first 24 rows of a stay.

# Take the first 24 rows of every stay_id.
first_24 = df.groupby('stay_id').head(24)

# Keep a stay only when none of those rows has 'AKI' >= 1 and none has
# 'dead' >= 1.
filtered_df = first_24.groupby('stay_id').filter(
    lambda stay: not ((stay['AKI'] >= 1).any() or (stay['dead'] >= 1).any())
)

# Restrict the full frame to the stay_ids that passed the check.
eligible_ids = filtered_df['stay_id'].unique()
df = df[df['stay_id'].isin(eligible_ids)]

len(df.stay_id.unique())

15459

In [28]:
# For every row, the number of rows in the same stay whose category_count is
# 3 or more.
count_3_or_more = (df['category_count'] >= 3).groupby(df['stay_id']).transform('sum')

# Keep only the stays in which that never happens.
df = df[count_3_or_more < 1]
len(df.stay_id.unique())

15246

In [29]:
# Rows missing any of these measurements are removed.
required_cols = ['HR','SBP','DBP','MAP','temp','RR','weight']
df = df.dropna(subset=required_cols)
len(df.stay_id.unique())

15234

In [30]:
# Action
#   0     -> category_count == 0
#   1..8  -> category_count == 1, keyed by the flagged category column
#   9, 10 -> category_count == 2 with Vancomycin plus Cephalosporins / Betalactam_comb
#   11    -> every other category_count == 2 row
# Assignments run in the listed order, so a later match overwrites an earlier one.
single_category_codes = [
    ('Cephalosporins', 1),
    ('Vancomycin', 2),
    ('Betalactam_comb', 3),
    ('Metronidazole', 4),
    ('Carbapenems', 5),
    ('Penicillins', 6),
    ('Fluoroquinolones', 7),
    ('Others', 8),
]
vancomycin_pair_codes = [
    ('Cephalosporins', 9),
    ('Betalactam_comb', 10),
]

df.loc[df['category_count']==0, 'action'] = 0

for category, code in single_category_codes:
    df.loc[(df['category_count']==1)&(df[category]==1), 'action'] = code

df.loc[(df['category_count']==2), 'action'] = 11
for partner, code in vancomycin_pair_codes:
    df.loc[(df['category_count']==2)&(df[partner]==1)&(df['Vancomycin']==1), 'action'] = code

In [31]:
# Row count per action code.
df.action.value_counts()

action
0.0     1243904
1.0       30142
2.0       26248
3.0       13116
4.0        7815
6.0        7290
5.0        7152
8.0        6107
7.0        4638
11.0       3680
9.0        2034
10.0        892
Name: count, dtype: int64

In [32]:
# Start with discharge = 0 everywhere; process_group sets it to 1 on the last
# row of stays whose 'dead' column sums to 0.
df['discharge'] = 0

def process_group(group):
    """Finalize the rows of one stay.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows of a single stay (assumed to be in chronological order) holding
        a ``dead`` column and a ``discharge`` column.

    Returns
    -------
    pandas.DataFrame
        * ``dead`` sums to 0: a copy of ``group`` whose last row has
          ``discharge`` set to 1.
        * otherwise: the rows up to and including the first row at which
          ``dead`` reaches its maximum; later rows are dropped and
          ``discharge`` is left as passed in.
    """
    if group['dead'].sum() == 0:
        # Work on a copy: pandas does not support mutating the object that
        # groupby.apply passes to the function.
        group = group.copy()
        group.iloc[-1, group.columns.get_loc('discharge')] = 1
        return group

    # Truncate by position, not by label, so repeated index labels cannot
    # widen the slice or raise.
    first_dead_pos = group['dead'].argmax()
    return group.iloc[:first_dead_pos + 1]

# Group by stay_id and run process_group on every group.
# The columns are selected explicitly, grouping key included: the result must
# keep 'stay_id', and relying on apply() to pass the grouping column
# implicitly is deprecated in recent pandas.
df = df.groupby('stay_id')[list(df.columns)].apply(process_group).reset_index(drop=True)

  df = df.groupby('stay_id').apply(process_group).reset_index(drop=True)


In [33]:
# presense_SOFA: 1 where SOFA is recorded, 0 where it is NaN.
df['presense_SOFA'] = df['SOFA'].notna().astype('int64')

# presense_BUN/SCr: 1 only where the ratio itself is recorded and neither
# presense_BUN nor presense_SCr is 0.
ratio_present = (
    df['BUN/SCr'].notna()
    & (df['presense_BUN'] != 0)
    & (df['presense_SCr'] != 0)
)
df['presense_BUN/SCr'] = ratio_present.astype('int64')

# Where the flag is 0 the ratio value is zeroed.
df.loc[~ratio_present, 'BUN/SCr'] = 0

# Every remaining NaN in the frame becomes 0.
df = df.fillna(0)

In [34]:
# One indicator column per action code (action_0 ... action_11): 1 on the rows
# carrying that code. Other rows are not assigned here.
for code in range(12):
    df.loc[df['action']==code, 'action_' + str(code)] = 1

In [35]:
from tqdm import tqdm
def zeropadding(df, pad_hours=24):
    """Prepend empty hourly rows to every stay.

    For each ``stay_id``, ``pad_hours`` rows spaced one hour apart are added,
    the last of them one hour before the stay's earliest ``charttime``. The
    added rows carry only ``stay_id`` and ``charttime``; every other column is
    NaN in them.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``stay_id`` and a datetime ``charttime`` column. Rows
        with a missing ``stay_id`` belong to no group and are left out.
    pad_hours : int, default 24
        Number of rows to prepend per stay.

    Returns
    -------
    pandas.DataFrame
        Stays in order of first appearance, each sorted by ``charttime``; the
        previous index is kept as an ``index`` column.
    """
    one_hour = pd.Timedelta(hours=1)
    padded = []
    # Iterating the groupby visits each stay's rows once, instead of scanning
    # the whole frame with a boolean mask for every stay_id.
    for stay_id, stay_rows in tqdm(df.groupby('stay_id', sort=False)):
        endtime = stay_rows['charttime'].min() - one_hour
        starttime = endtime - (pad_hours - 1) * one_hour
        pad_times = pd.date_range(start=starttime, end=endtime, freq='h')

        pad_rows = pd.DataFrame({'stay_id': stay_id, 'charttime': pad_times})
        padded.append(pd.concat([stay_rows, pad_rows]).sort_values(by='charttime'))
    rt = pd.concat(padded)
    rt.reset_index(inplace=True)
    return rt
# Replace df with the padded frame returned by zeropadding.
df = zeropadding(df)

  0%|          | 0/15234 [00:00<?, ?it/s]

100%|██████████| 15234/15234 [01:10<00:00, 216.23it/s]


In [36]:
# traj: integer code per stay_id, numbered in order of first appearance.
traj_codes = pd.factorize(df['stay_id'])[0]
df['traj'] = traj_codes
df = df.sort_values(by=['traj','charttime'])
# step: running row number inside each stay after the sort.
df['step'] = df.groupby('stay_id').cumcount()

In [37]:
# One-column frame listing the column names of df at this point.
adas2=pd.DataFrame(df.columns)

In [40]:
# Final column layout. change_columns later relies on position: the first two
# columns are left untouched and the last four are handled separately from
# the block in between.
id_cols = ['traj', 'step']
static_cols = ['age', 'gender', 'weight',
               'WHITE', 'BLACK', 'HISPANIC OR LATINO', 'ASIAN', 'OTHER', 'UNKNOWN',
               'LD', 'DH', 'HYP', 'CKD', 'MI', 'DM', 'VD', 'CHF', 'COPD', 'baseline_SCr']
vital_cols = ['HR', 'SBP', 'DBP', 'MAP', 'temp', 'RR', 'CVP', 'SaO2', 'FiO2']
lab_cols = ['Alb', 'Alk_Phos', 'AG', 'BUN', 'Ca', 'CK', 'D_Bil', 'Glu', 'HCT', 'INR', 'PH', 'PHOS',
            'Platelet', 'Cl', 'SCr', 'Na', 'Potassium', 'T_Bil', 'WBC', 'Gl', 'Mg',
            'Ca_ion', 'HCO3', 'AST', 'ALT', 'PTT', 'baseexcess', 'lactate', 'PaO2', 'PaCO2']
# One 'presense_' flag per measurement above, in the same order, followed by
# the flags that have no counterpart in those two lists.
presence_cols = ['presense_' + name for name in vital_cols + lab_cols] + [
    'presense_SOFA', 'presense_BUN/SCr', 'presense_AKI_UO', 'presense_AKI_SCr', 'presense_AKI']
derived_cols = ['uo', 'SOFA', 'AKI_UO', 'AKI_SCr', 'AKI', 'ventilation', 'fluid', 'vaso_equ',
                'SCr/baseline_SCr', 'delta_SCr', 'BUN/SCr']
action_flag_cols = ['action_' + str(code) for code in range(12)]
tail_cols = ['action', 'dead', 'discharge', 'AKI_stage3']

df = df[id_cols + static_cols + vital_cols + lab_cols + presence_cols
        + derived_cols + action_flag_cols + tail_cols]

In [42]:
# Columns that are back-filled per trajectory in the next cell.
demographic_cols = ['age', 'gender', 'weight']
ethnicity_cols = ['WHITE', 'BLACK', 'HISPANIC OR LATINO', 'ASIAN', 'OTHER', 'UNKNOWN']
history_cols = ['LD', 'DH', 'HYP', 'CKD', 'MI', 'DM', 'VD', 'CHF', 'COPD']
col = demographic_cols + ethnicity_cols + history_cols + ['baseline_SCr']

In [43]:
# Within each traj, fill missing values in `col` from the next non-null row.
# The rows added by zeropadding sort before the recorded rows and have NaN in
# these columns, so they take the values of the first recorded row.
df[col] = df.groupby('traj')[col].bfill()

In [44]:
# Number of distinct trajectories.
len(df.traj.unique())

15234

In [45]:
# Every NaN still present (e.g. the non-back-filled columns of the padding
# rows) becomes 0.
df.fillna(0,inplace=True)

In [46]:
def make_train_val_test_split(df, train_frac=0.75, val_frac=0.05, seed=None):
    """Split trajectories into train / val / test, stratified by outcome.

    Each trajectory (``traj``) falls into one of four strata depending on
    whether its ``AKI_stage3`` and ``dead`` columns sum to a positive value or
    to zero. Every stratum is shuffled and cut by ``train_frac`` and
    ``val_frac``; whatever is left goes to test. All rows of a trajectory land
    in the same split.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``traj``, ``AKI_stage3`` and ``dead``.
    train_frac, val_frac : float
        Share of each stratum (rounded with ``np.round``) assigned to the
        train and validation splits.
    seed : int or None, default None
        When given, a private ``np.random.RandomState(seed)`` does the
        shuffling so the split is reproducible. When None, NumPy's global
        random state is used.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(train_df, val_df, test_df)``.
    """
    # One grouped pass instead of filtering the full frame per trajectory.
    totals = df.groupby('traj', sort=False)[['AKI_stage3', 'dead']].sum()
    has_aki = (totals['AKI_stage3'] > 0).to_numpy()
    no_aki = (totals['AKI_stage3'] == 0).to_numpy()
    has_dead = (totals['dead'] > 0).to_numpy()
    no_dead = (totals['dead'] == 0).to_numpy()

    # The strata hold the trajectory ids themselves (not their positions), so
    # the isin() lookups below are valid whatever values `traj` takes.
    dead_aki = totals.index[has_aki & has_dead].tolist()
    dead_nonaki = totals.index[no_aki & has_dead].tolist()
    surv_aki = totals.index[has_aki & no_dead].tolist()
    surv_nonaki = totals.index[no_aki & no_dead].tolist()

    print("dead_aki:",len(dead_aki),"dead_nonaki:",len(dead_nonaki),"surv_aki:",len(surv_aki),"surv_nonaki:",len(surv_nonaki))

    shuffler = np.random if seed is None else np.random.RandomState(seed)

    train_traj, val_traj, test_traj = [], [], []
    for members in (dead_aki, dead_nonaki, surv_aki, surv_nonaki):
        shuffler.shuffle(members)
        train_end = int(np.round(train_frac*len(members),0))
        val_end = int(np.round(val_frac*len(members),0)) + train_end
        train_traj.extend(members[:train_end])
        val_traj.extend(members[train_end:val_end])
        test_traj.extend(members[val_end:])

    train_df = df[df['traj'].isin(train_traj)]
    val_df   = df[df['traj'].isin(val_traj)]
    test_df  = df[df['traj'].isin(test_traj)]

    print('==============================')
    print("train_df:",len(train_df), "val_df:",len(val_df), "test_df:",len(test_df))
    print("train_df.traj.unique():",len(train_df.traj.unique()), "val_df.traj.unique():",len(val_df.traj.unique()), "test_df.traj.unique():",len(test_df.traj.unique()))

    return train_df, val_df, test_df

# Stratified split with the default fractions.
# NOTE(review): the shuffling inside make_train_val_test_split draws from
# NumPy's global random state and no seed is set in the visible cells, so the
# split differs between runs.
train_df, val_df, test_df = make_train_val_test_split(df)

dead_aki: 172 dead_nonaki: 337 surv_aki: 1161 surv_nonaki: 13564
train_df: 1278324 val_df: 88932 test_df: 347715
train_df.traj.unique(): 11426 val_df.traj.unique(): 762 test_df.traj.unique(): 3046


In [47]:
def change_columns(df, ref=None):
    """Min-max scale the state columns and prefix every column by role.

    Columns are treated by position:

    * the first two are left untouched;
    * the last four are renamed only: ``action`` becomes ``a:action``, the
      others get an ``r:`` prefix;
    * everything in between is min-max scaled and gets an ``s:`` prefix.

    ``df`` is modified in place and also returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to scale and rename.
    ref : pandas.DataFrame or None, default None
        Frame (with the original, unprefixed column names) whose per-column
        min and max define the scaling. Defaults to ``df`` itself. Passing
        the training frame here scales other splits with training statistics.

    Returns
    -------
    pandas.DataFrame
        The same object as ``df``.
    """
    stats = df if ref is None else ref
    names = list(df.columns)
    state_end = len(names) - 4
    renamed = {}
    for pos, name in enumerate(names):
        if pos < 2:
            continue
        if pos < state_end:
            low = stats[name].min()
            span = stats[name].max() - low
            # A constant column has zero range; dividing by it would turn the
            # whole column into NaN, so it is only shifted instead.
            df[name] = (df[name] - low) / (span if span != 0 else 1)
            renamed[name] = 's:' + name
        elif name == 'action':
            renamed[name] = 'a:' + name
        else:
            renamed[name] = 'r:' + name
    df.rename(columns=renamed, inplace=True)
    return df

# Scale and rename each split.
# NOTE(review): change_columns takes min/max from the frame it is given, so
# each split is scaled with its own statistics and the same raw value can map
# to different scaled values in train, val and test.
train_df = change_columns(train_df)
val_df = change_columns(val_df)
test_df = change_columns(test_df)

#train_df.to_parquet('./code/train.parquet')
#val_df.to_parquet('./code/val.parquet')
#test_df.to_parquet('./code/test.parquet')

In [48]:
# Total NaN count per split. A cell displays only its last expression, so the
# three totals are collected into one dict to make all of them visible.
{
    'train': int(train_df.isna().sum().sum()),
    'val': int(val_df.isna().sum().sum()),
    'test': int(test_df.isna().sum().sum()),
}

0

In [54]:
# One-column frame with every column name of train_df. It is not called `all`
# so the builtin all() stays usable in later cells.
all_columns = pd.DataFrame(list(train_df.columns))

In [52]:
# One-column frame with the names of the 's:'-prefixed columns.
sa = pd.DataFrame([name for name in train_df.columns if name.startswith('s:')])

In [55]:
# Number of 's:'-prefixed columns.
sum(1 for name in train_df.columns if name.startswith('s:'))

125

In [49]:
# Write the three splits next to the code.
split_outputs = {
    './code/train.parquet': train_df,
    './code/val.parquet': val_df,
    './code/test.parquet': test_df,
}
for out_path, split_frame in split_outputs.items():
    split_frame.to_parquet(out_path)

In [None]:
# NaN count per column of the training split.
train_df.isna().sum()