In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys, os, pickle, utils, math, tqdm

from datetime import timedelta
#from utils import baseline_SCr

if os.getcwd()[-4:] == "code":
    os.chdir('../')

icu = './data/mimic-iv-2.2-parquet/icu/'
hosp = './data/mimic-iv-2.2-parquet/hosp/'

pd.set_option('mode.chained_assignment',  None) # 경고 off

In [2]:
df = pd.read_parquet('./data/resample/resample_label.parquet')
len(df.stay_id.unique())

73181

In [4]:
df.columns

Index(['subject_id', 'hadm_id', 'stay_id', 'charttime', 'HR', 'SBP', 'DBP',
       'MAP', 'temp', 'RR', 'CVP', 'SaO2', 'FiO2', 'Alb', 'Alk_Phos', 'AG',
       'BUN', 'Ca', 'CK', 'D_Bil', 'Glu', 'HCT', 'INR', 'PH', 'PHOS',
       'Platelet', 'Cl', 'SCr', 'Na', 'Potassium', 'T_Bil', 'WBC', 'Gl', 'Mg',
       'Ca_ion', 'HCO3', 'AST', 'ALT', 'PTT', 'baseexcess', 'lactate', 'PaO2',
       'PaCO2', 'gender', 'WHITE', 'BLACK', 'HISPANIC OR LATINO', 'ASIAN',
       'MULTIPLE RACE', 'OTHER', 'UNKNOWN', 'weight', 'age', 'LD', 'DH', 'HYP',
       'CKD', 'MI', 'DM', 'VD', 'CHF', 'COPD', 'baseline_SCr', 'MDRD', 'uo',
       'SOFA', 'AKI_UO', 'AKI_SCr', 'AKI', 'ventilation', 'fluid', 'vaso_equ',
       'Cephalosporins', 'Vancomycin', 'Betalactam_comb', 'Metronidazole',
       'Carbapenems', 'Penicillins', 'Fluoroquinolones', 'Others',
       'count_category', 'SCr/baseline_SCr', 'delta_SCr', 'BUN/SCr', 'dead',
       'AKI_stage3'],
      dtype='object')

In [3]:
df.sort_values(by=['subject_id','hadm_id','stay_id','charttime'],inplace=True)
df['step'] = df.groupby('stay_id').cumcount()

In [4]:
# 24시간 이내 퇴원
df = df.groupby('stay_id').filter(lambda x:len(x) > 24)
len(df.stay_id.unique())

57734

In [5]:
# 24시간 이내 AKI 발생

# 각 stay_id에 대해 처음 24개 행을 선택
first_24 = df.groupby('stay_id').head(24)[['step','subject_id','hadm_id','stay_id','charttime','uo','AKI_UO','SCr','baseline_SCr','AKI_SCr','AKI','dead']]

# 'AKI' 지표가 1 이상인 행을 포함하는 stay_id를 제외
filtered_df = first_24.groupby('stay_id').filter(lambda x: not (x['AKI'] >= 1).any())
filtered_df = filtered_df.groupby('stay_id').filter(lambda x: not (x['dead'] >= 1).any())

# 원래 DataFrame에서 필터링된 stay_id만 선택
df = df[df['stay_id'].isin(filtered_df['stay_id'].unique())]

len(df.stay_id.unique())

15721

In [6]:
# 각 stay_id별로 count_category가 3 이상인 경우의 개수 계산
count_3_or_more = df.groupby('stay_id')['count_category'].transform(lambda x: (x >= 3).sum())

# count_category가 3 이상인 경우가 없는 stay_id만 필터링
df = df[count_3_or_more < 1]
len(df.stay_id.unique())

15502

In [7]:
df.dropna(subset=['HR','SBP','DBP','MAP','temp','RR','weight'],inplace=True)
len(df.stay_id.unique())

15425

In [8]:
# Action
df.loc[df['count_category']==0, 'action'] = 0
df.loc[(df['count_category']==1)&(df['Cephalosporins']==1), 'action'] = 1
df.loc[(df['count_category']==1)&(df['Vancomycin']==1), 'action'] = 2
df.loc[(df['count_category']==1)&(df['Betalactam_comb']==1), 'action'] = 3
df.loc[(df['count_category']==1)&(df['Metronidazole']==1), 'action'] = 4
df.loc[(df['count_category']==1)&(df['Carbapenems']==1), 'action'] = 5
df.loc[(df['count_category']==1)&(df['Penicillins']==1), 'action'] = 6
df.loc[(df['count_category']==1)&(df['Fluoroquinolones']==1), 'action'] = 7
df.loc[(df['count_category']==1)&(df['Others']==1), 'action'] = 8

df.loc[(df['count_category']==2), 'action'] = 11
df.loc[(df['count_category']==2)&(df['Cephalosporins']==1)&(df['Vancomycin']==1), 'action'] = 9
df.loc[(df['count_category']==2)&(df['Betalactam_comb']==1)&(df['Vancomycin']==1), 'action'] = 10

In [9]:
df.action.value_counts()

action
0.0     1263816
1.0       30444
2.0       26653
3.0       13462
4.0        8091
6.0        7367
5.0        7308
8.0        6220
7.0        4711
11.0       3776
9.0        2066
10.0        905
Name: count, dtype: int64

In [10]:
len(df)

1374819

In [11]:
df['discharge'] = 0

def process_group(group):
    if group['dead'].sum() == 0:  # dead 열의 합이 0이면 모든 값이 0임
        group.iloc[-1, group.columns.get_loc('discharge')] = 1  # 마지막 행의 discharge 값을 1로 설정
    else:
        first_dead_index = group['dead'].idxmax()  # dead 열에서 1이 처음으로 나타나는 인덱스
        group = group.loc[:first_dead_index]  # 첫 번째 dead 이후의 행들을 제거
    return group

# stay_id별로 그룹화하고 각 그룹에 대해 process_group 함수 적용
df = df.groupby('stay_id').apply(process_group).reset_index(drop=True)

  df = df.groupby('stay_id').apply(process_group).reset_index(drop=True)


In [14]:
'''dead = df.pop('dead')
discharge = df.pop('discharge')
AKI_stage3 = df.pop('AKI_stage3')'''

In [12]:
df['traj'] = pd.factorize(df['stay_id'])[0]

In [13]:
df = df[['traj','step', 'age', 'gender','weight',
       'WHITE', 'BLACK', 'HISPANIC OR LATINO', 'ASIAN','MULTIPLE RACE', 'OTHER', 'UNKNOWN',
       'LD', 'DH', 'HYP', 'CKD', 'MI', 'DM', 'VD', 'CHF', 'COPD', 'baseline_SCr',
       'HR', 'SBP', 'DBP', 'MAP', 'temp', 'RR', 'CVP', 'SaO2', 'FiO2', 'PaO2','PaCO2',
       'Alb', 'Alk_Phos', 'AG',
       'BUN', 'Ca', 'CK', 'D_Bil', 'Glu', 'HCT', 'INR', 'PH', 'PHOS',
       'Platelet', 'Cl', 'SCr', 'Na', 'Potassium', 'T_Bil', 'WBC', 'Gl', 'Mg',
       'Ca_ion', 'HCO3', 'AST', 'ALT', 'PTT', 'baseexcess', 'lactate', 
       'uo', 'SOFA', 'AKI_UO', 'AKI_SCr', 'AKI', 'ventilation', 'fluid', 'vaso_equ',
       'SCr/baseline_SCr', 'delta_SCr', 'BUN/SCr',
       'action','dead','discharge','AKI_stage3']]

In [36]:
def make_train_val_test_split(df, train_frac=0.75, val_frac=0.05):
    all_traj = df['traj'].unique()
    all_AKI = []
    all_dead = []
    for traj in all_traj:
        aki = df[df['traj'] == traj]['AKI_stage3'].sum()
        dead = df[df['traj'] == traj]['dead'].sum()
        all_AKI.append(aki)
        all_dead.append(dead)
    dead_aki    = [x for x in range(len(all_traj)) if (all_AKI[x] > 0)&(all_dead[x] > 0)]
    dead_nonaki = [x for x in range(len(all_traj)) if (all_AKI[x] == 0)&(all_dead[x] > 0)]
    surv_aki    = [x for x in range(len(all_traj)) if (all_AKI[x] > 0)&(all_dead[x] == 0)]
    surv_nonaki = [x for x in range(len(all_traj)) if (all_AKI[x] == 0)&(all_dead[x] == 0)]

    print("dead_aki:",len(dead_aki),"dead_nonaki:",len(dead_nonaki),"surv_aki:",len(surv_aki),"surv_nonaki:",len(surv_nonaki))

    np.random.shuffle(dead_aki)
    np.random.shuffle(dead_nonaki)
    np.random.shuffle(surv_aki)
    np.random.shuffle(surv_nonaki)

    train_dead_aki_index    = int(np.round(train_frac*len(dead_aki),0))
    train_dead_nonaki_index = int(np.round(train_frac*len(dead_nonaki),0))
    train_surv_aki_index    = int(np.round(train_frac*len(surv_aki),0))
    train_surv_nonaki_index = int(np.round(train_frac*len(surv_nonaki),0))

    val_dead_aki_index      = int(np.round(val_frac*len(dead_aki),0)) + train_dead_aki_index
    val_dead_nonaki_index   = int(np.round(val_frac*len(dead_nonaki),0)) + train_dead_nonaki_index
    val_surv_aki_index      = int(np.round(val_frac*len(surv_aki),0)) + train_surv_aki_index
    val_surv_nonaki_index   = int(np.round(val_frac*len(surv_nonaki),0)) + train_surv_nonaki_index

    train_traj = dead_aki[:train_dead_aki_index]
    train_traj.extend(dead_nonaki[:train_dead_nonaki_index])
    train_traj.extend(surv_aki[:train_surv_aki_index])
    train_traj.extend(surv_nonaki[:train_surv_nonaki_index])

    val_traj = dead_aki[train_dead_aki_index:val_dead_aki_index]
    val_traj.extend(dead_nonaki[train_dead_nonaki_index:val_dead_nonaki_index])
    val_traj.extend(surv_aki[train_surv_aki_index:val_surv_aki_index])
    val_traj.extend(surv_nonaki[train_surv_nonaki_index:val_surv_nonaki_index])

    test_traj = dead_aki[val_dead_aki_index:]
    test_traj.extend(dead_nonaki[val_dead_nonaki_index:])
    test_traj.extend(surv_aki[val_surv_aki_index:])
    test_traj.extend(surv_nonaki[val_surv_nonaki_index:])

    train_df = df[df['traj'].isin(train_traj)]
    val_df   = df[df['traj'].isin(val_traj)]
    test_df  = df[df['traj'].isin(test_traj)]

    print('==============================')
    print("train_df:",len(train_df), "val_df:",len(val_df), "test_df:",len(test_df))
    print("train_df.traj.unique():",len(train_df.traj.unique()), "val_df.traj.unique():",len(val_df.traj.unique()), "test_df.traj.unique():",len(test_df.traj.unique()))

    
    train_df.to_parquet('./data/train.parquet')
    val_df.to_parquet('./data/val.parquet')
    test_df.to_parquet('./data/test.parquet')

    return train_df, val_df, test_df

train_df, val_df, test_df = make_train_val_test_split(df)

dead_aki: 175 dead_nonaki: 338 surv_aki: 1183 surv_nonaki: 13729
len(train_df): 1031688 len(val_df): 68539 len(test_df): 270899
len(train_df.traj.unique()): 11569 len(val_df.traj.unique()): 771 len(test_df.traj.unique()): 3085


In [49]:
def change_columns(df):
    for i,idx in enumerate(df.columns):
        if i < 2 : pass
        elif i >= 2 and i < len(df.columns) - 4 : df.rename(columns={idx:'s:'+idx},inplace=True)
        elif idx == 'action': df.rename(columns={idx:'a:'+idx},inplace=True)
        else : df.rename(columns={idx:'r:'+idx},inplace=True)
    return df

train_df = change_columns(train_df)
val_df = change_columns(val_df)
test_df = change_columns(test_df)

train_df.to_parquet('./data/train.parquet')
val_df.to_parquet('./data/val.parquet')
test_df.to_parquet('./data/test.parquet')

In [52]:
test_df.columns

Index(['traj', 'step', 's:age', 's:gender', 's:weight', 's:WHITE', 's:BLACK',
       's:HISPANIC OR LATINO', 's:ASIAN', 's:MULTIPLE RACE', 's:OTHER',
       's:UNKNOWN', 's:LD', 's:DH', 's:HYP', 's:CKD', 's:MI', 's:DM', 's:VD',
       's:CHF', 's:COPD', 's:baseline_SCr', 's:HR', 's:SBP', 's:DBP', 's:MAP',
       's:temp', 's:RR', 's:CVP', 's:SaO2', 's:FiO2', 's:PaO2', 's:PaCO2',
       's:Alb', 's:Alk_Phos', 's:AG', 's:BUN', 's:Ca', 's:CK', 's:D_Bil',
       's:Glu', 's:HCT', 's:INR', 's:PH', 's:PHOS', 's:Platelet', 's:Cl',
       's:SCr', 's:Na', 's:Potassium', 's:T_Bil', 's:WBC', 's:Gl', 's:Mg',
       's:Ca_ion', 's:HCO3', 's:AST', 's:ALT', 's:PTT', 's:baseexcess',
       's:lactate', 's:uo', 's:SOFA', 's:AKI_UO', 's:AKI_SCr', 's:AKI',
       's:ventilation', 's:fluid', 's:vaso_equ', 's:SCr/baseline_SCr',
       's:delta_SCr', 's:BUN/SCr', 'a:action', 'r:dead', 'r:discharge',
       'r:AKI_stage3'],
      dtype='object')

In [46]:
train_df.action.value_counts()/len(train_df)

action
0.0     947028
1.0      23098
2.0      20284
3.0      10371
4.0       6084
6.0       5618
5.0       5605
8.0       4867
7.0       3571
11.0      2913
9.0       1563
10.0       686
Name: count, dtype: int64

In [47]:
val_df.action.value_counts()/len(val_df)*100

action
0.0     63264
1.0      1434
2.0      1305
3.0       631
4.0       384
8.0       339
5.0       315
6.0       308
7.0       281
11.0      140
9.0        84
10.0       54
Name: count, dtype: int64

In [48]:
test_df.action.value_counts()

action
0.0     250001
1.0       5873
2.0       5013
3.0       2428
4.0       1619
6.0       1430
5.0       1381
8.0       1004
7.0        856
11.0       719
9.0        414
10.0       161
Name: count, dtype: int64