In [1]:
import numpy as np
import pandas as pd
import tqdm
import argparse
import os
from ai_clinician.modeling.normalization import DataNormalization
from ai_clinician.preprocessing.utils import load_csv
from ai_clinician.preprocessing.columns import *
from ai_clinician.modeling.columns import *
from scipy.stats import zscore
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

tqdm.tqdm.pandas()

def save_data_files(dir, MIMICraw, MIMICzs, metadata):
    MIMICraw.to_csv(os.path.join(dir, "MIMICraw.csv"), index=False)
    MIMICzs.to_csv(os.path.join(dir, "MIMICzs.csv"), index=False)
    metadata.to_csv(os.path.join(dir, "metadata.csv"), index=False)



In [2]:
pwd

'/home/lkapral/RRT_mimic_iv'

In [3]:
def create_args():
    parser = argparse.ArgumentParser(description=(
        'Generates a train/test split of the MIMIC-IV dataset, and generates files labeled '
        '{train|test}/MIMICraw.npy and {train|test}/MIMICzs.npy.'
    ))
    parser.add_argument('input', type=str,
                        help='Data directory (should contain mimic_dataset.csv and aki_cohort.csv)')
    parser.add_argument('output', type=str,
                        help='Directory in which to output')
    parser.add_argument('--train-size', dest='train_size', type=float, default=0.7,
                        help='Proportion of data to use in training (default 0.7)')
    parser.add_argument('--outcome', dest='outcome_col', type=str, default='died_in_hosp',
                        help='Name of column to use for outcomes (probably "died_in_hosp" [default] or "morta_90")')
    
    # Simulate input arguments as if they were passed from the command line
    simulated_input = [
        '/home/lkapral/RRT_mimic_iv/data/mimic',    # Replace with your actual input directory
        '/home/lkapral/RRT_mimic_iv/data/model',   # Replace with your actual output directory
        '--train-size', '0.7',
        '--outcome', 'morta_90'
    ]
    return parser.parse_args(simulated_input)

# Create args object
args = create_args()

in_dir = args.input
out_dir = args.output
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

# Find sepsis cohort in the mimic dataset
mdp_data = load_csv(os.path.join(in_dir, "mimic_dataset.csv"))
aki_cohort = load_csv(os.path.join(in_dir, "aki_cohort.csv"))

print(list(mdp_data.columns))


MIMICtable = mdp_data[mdp_data[C_ICUSTAYID].isin(aki_cohort[C_ICUSTAYID])].reset_index(drop=True)
assert args.outcome_col in MIMICtable.columns, "Outcome column '{}' not found in MIMICtable".format(args.outcome_col)



# Define RRT-related columns
rrt_cols = [
    'Ultrafiltrate_Output',
    'Blood_Flow',
    'Hourly_Patient_Fluid_Removal',
    'Dialysate_Rate',
    'Hemodialysis_Output',  # Ensure the column name matches your DataFrame
    'Citrate',
    'Prefilter_Replacement_Rate',
    'Postfilter_Replacement_Rate'
]





['bloc', 'icustayid', 'timestep', 'gender', 'age', 'elixhauser', 're_admission', 'died_in_hosp', 'died_within_48h_of_out_time', 'morta_90', 'delay_end_of_record_and_discharge_or_death', 'Height_cm', 'Weight_kg', 'GCS', 'RASS', 'HR', 'SysBP', 'MeanBP', 'DiaBP', 'RR', 'SpO2', 'Temp_C', 'Temp_F', 'CVP', 'PAPsys', 'PAPmean', 'PAPdia', 'CI', 'SVR', 'Interface', 'FiO2_100', 'FiO2_1', 'O2flow', 'PEEP', 'TidalVolume', 'MinuteVentil', 'PAWmean', 'PAWpeak', 'PAWplateau', 'Respiratory_Rate', 'Ultrafiltrate_Output', 'Blood_Flow', 'Hourly_Patient_Fluid_Removal', 'Dialysate_Rate', 'APACHEII_Renal_Failure', 'Hemodialysis_Output', 'Citrate', 'Prefilter_Replacement_Rate', 'Postfilter_Replacement_Rate', 'Potassium', 'Sodium', 'Chloride', 'Glucose', 'BUN', 'Creatinine', 'Magnesium', 'Calcium', 'Ionised_Ca', 'CO2_mEqL', 'SGOT', 'SGPT', 'Total_bili', 'Direct_bili', 'Total_protein', 'Albumin', 'Troponin', 'CRP', 'Hb', 'Ht', 'RBC_count', 'WBC_count', 'Platelets_count', 'PTT', 'PT', 'ACT', 'INR', 'Arterial_pH

In [4]:
# Create 'action' column
rrt_actions = (~MIMICtable[rrt_cols].isna() & (MIMICtable[rrt_cols] != 0)).any(axis=1)
MIMICtable['action'] = rrt_actions.astype(int)

# Actions array
actions = MIMICtable['action'].values

np.seterr(divide='ignore', invalid='ignore')

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

In [5]:
pd.set_option('display.max_columns', None)

In [6]:
icu_stays = pd.read_csv('/home/lkapral/RRT_mimic_iv/data/icustays.csv')
icd_diagnoses = pd.read_csv('/home/lkapral/RRT_mimic_iv/data/d_icd_diagnoses.csv')
diagnose_icd = pd.read_csv('/home/lkapral/RRT_mimic_iv/data/diagnoses_icd.csv')
exclude_idc = pd.read_csv('/home/lkapral/RRT_mimic_iv/data/exclusion.csv')

exclude_list = exclude_idc['icd_code'].to_list()

merged_df = MIMICtable.merge(icu_stays[['stay_id', 'subject_id']], left_on='icustayid', right_on='stay_id', how='left')

# Filter rows where icd_code is in exclude_list
excluded_rows = diagnose_icd[diagnose_icd['icd_code'].isin(exclude_list)]
# Get the unique hadm_id values from these rows
excluded_subject_ids = excluded_rows['subject_id'].unique().tolist()

print('Number of Patients with kidney issues. ', len(excluded_subject_ids))

print('Number of patients before exclusion:', len(merged_df['icustayid'].unique()))

print('Number of patients with RRT before exclusion:' , len(merged_df[merged_df['action']>0]['icustayid'].unique()))

merged_df = merged_df[~merged_df['subject_id'].isin(excluded_subject_ids)]

print('Number of patients after exclusion:', len(merged_df['icustayid'].unique()))

print('Number of patients with RRT after exclusion:' , len(merged_df[merged_df['action']>0]['icustayid'].unique()))

merged_df.drop(columns=['stay_id', 'subject_id'], inplace=True)

MIMICtable = merged_df

Number of Patients with kidney issues.  5055
Number of patients before exclusion: 59851
Number of patients with RRT before exclusion: 4002
Number of patients after exclusion: 54859
Number of patients with RRT after exclusion: 2055


In [7]:
'''
Number of Patients with kidney issues.  5055
Number of patients before exclusion: 59862
Number of patients with RRT before exclusion: 4002
Number of patients after exclusion: 54870
Number of patients with RRT after exclusion: 2055
'''

'\nNumber of Patients with kidney issues.  5055\nNumber of patients before exclusion: 59862\nNumber of patients with RRT before exclusion: 4002\nNumber of patients after exclusion: 54870\nNumber of patients with RRT after exclusion: 2055\n'

In [8]:


import pandas as pd

# Assuming MIMICtable is your original DataFrame

# 1. Store the original data types
original_dtypes = MIMICtable.dtypes.to_dict()

# 2. Create a 'day' column by dividing 'bloc' by 3 (ensure integer division if needed)
MIMICtable['day'] = MIMICtable['bloc'] // 3

# 3. Define the columns for different aggregation functions
sum_cols = [
    'input_total', 'input_step', 'output_total', 'output_step',
    'cumulated_balance', 'median_dose_vaso', 'max_dose_vaso'
]

max_cols = ['mechvent', 'extubated', 'action']

first_cols = ['gender', 'age', 'elixhauser', 're_admission', 'Height_cm', 'Weight_kg']

# 4. Identify columns to average (ensure they are numeric)
excluded_cols = set(sum_cols + max_cols + first_cols + ['icustayid', 'timestep', 'bloc', 'day'])
mean_cols = [col for col in MIMICtable.columns if col not in excluded_cols]

# 5. Create the aggregation dictionary
agg_dict = {col: 'sum' for col in sum_cols}
agg_dict.update({col: 'max' for col in max_cols})
agg_dict.update({col: 'first' for col in first_cols})
agg_dict.update({col: 'mean' for col in mean_cols})

# 6. Perform the groupby aggregation
# 6. Perform the groupby aggregation by encounterId and day
MIMICtable_agg = MIMICtable.groupby(['icustayid', 'day']).agg(agg_dict).reset_index()

# 7. Sort the aggregated DataFrame to ensure cumulative sums are computed in the right order
MIMICtable_agg = MIMICtable_agg.sort_values(by=['icustayid', 'day'])

# 8. Compute the cumulative sums for input_total and output_total using input_step and output_step
MIMICtable_agg['input_total'] = MIMICtable_agg.groupby('icustayid')['input_step'].cumsum()
MIMICtable_agg['output_total'] = MIMICtable_agg.groupby('icustayid')['output_step'].cumsum()

# 9. Calculate cumulated_balance as the sum of the cumulative totals
MIMICtable_agg['cumulated_balance'] = MIMICtable_agg['input_total'] + MIMICtable_agg['output_total']

# 10. (Optional) Restore the original data types if needed
for col in MIMICtable_agg.columns:
    if col in original_dtypes:
        original_dtype = original_dtypes[col]
        try:
            if pd.api.types.is_integer_dtype(original_dtype):
                MIMICtable_agg[col] = MIMICtable_agg[col].round().astype(original_dtype)
            else:
                MIMICtable_agg[col] = MIMICtable_agg[col].astype(original_dtype)
        except (ValueError, TypeError):
            print(f"Warning: Could not convert column '{col}' to {original_dtype}. Keeping the aggregated type.")

print(MIMICtable_agg.dtypes)

# 11. (Optional) Adjust 'bloc' if needed (here setting bloc as day + 1)
MIMICtable_agg['bloc'] = MIMICtable_agg['day'] + 1
MIMICtable_agg.drop(columns=['day'], inplace=True)

# Replace original DataFrame with the aggregated one if desired
MIMICtable = MIMICtable_agg



icustayid         int64
day               int64
input_total     float64
input_step      float64
output_total    float64
                 ...   
Insulin         float64
Shock_Index     float64
PaO2_FiO2       float64
SOFA              int64
SIRS              int64
Length: 121, dtype: object


  MIMICtable_agg['bloc'] = MIMICtable_agg['day'] + 1


In [9]:
for i, col in enumerate(MIMICtable.columns):
    print(i)
    print(col)
    print(MIMICtable[col].describe())
    print('-----------------------------------------')
    print()

0
icustayid
count    4.181470e+05
mean     3.498165e+07
std      2.882681e+06
min      3.000015e+07
25%      3.248068e+07
50%      3.495937e+07
75%      3.747017e+07
max      3.999986e+07
Name: icustayid, dtype: float64
-----------------------------------------

1
input_total
count    418147.000000
mean      10452.006046
std       27504.645963
min           0.000000
25%         333.689150
50%        2536.210000
75%        8099.928915
max      494685.920000
Name: input_total, dtype: float64
-----------------------------------------

2
input_step
count    418147.000000
mean       1350.605632
std        3399.446070
min           0.000000
25%           0.000000
50%         220.000000
75%         906.889500
max       45165.360000
Name: input_step, dtype: float64
-----------------------------------------

3
output_total
count    418147.000000
mean       5410.535414
std        7212.981596
min          -1.000000
25%         950.000000
50%        2685.000000
75%        6892.000000
max      2300

In [10]:
MIMICtable[MIMICtable['SvO2']>200]['SvO2']

300864     45523.25
300865    999999.00
Name: SvO2, dtype: float64

In [11]:
MIMICtable.loc[MIMICtable['Height_cm']>250, 'Height_cm'] = MIMICtable['Height_cm'].mean()

In [12]:
MIMICtable.loc[MIMICtable['Weight_kg']>400, 'Weight_kg'] = MIMICtable['Weight_kg'].mean( )

In [13]:
MIMICtable.loc[MIMICtable['O2flow']>100, 'O2flow'] = 100.

In [14]:
MIMICtable.loc[MIMICtable['SvO2']>100, 'SvO2'] = 100.

In [15]:
MIMICtable.loc[MIMICtable['paO2']<0, 'paO2'] = 0

In [16]:
MIMICtable.loc[MIMICtable['Hourly_Patient_Fluid_Removal']<0, 'Hourly_Patient_Fluid_Removal'] = 0

In [17]:
MIMICtable

Unnamed: 0,icustayid,input_total,input_step,output_total,output_step,cumulated_balance,median_dose_vaso,max_dose_vaso,mechvent,extubated,action,gender,age,elixhauser,re_admission,Height_cm,Weight_kg,died_in_hosp,died_within_48h_of_out_time,morta_90,delay_end_of_record_and_discharge_or_death,GCS,RASS,HR,SysBP,MeanBP,DiaBP,RR,SpO2,Temp_C,Temp_F,CVP,PAPsys,PAPmean,PAPdia,CI,SVR,Interface,FiO2_100,FiO2_1,O2flow,PEEP,TidalVolume,MinuteVentil,PAWmean,PAWpeak,PAWplateau,Respiratory_Rate,Ultrafiltrate_Output,Blood_Flow,Hourly_Patient_Fluid_Removal,Dialysate_Rate,APACHEII_Renal_Failure,Hemodialysis_Output,Citrate,Prefilter_Replacement_Rate,Postfilter_Replacement_Rate,Potassium,Sodium,Chloride,Glucose,BUN,Creatinine,Magnesium,Calcium,Ionised_Ca,CO2_mEqL,SGOT,SGPT,Total_bili,Direct_bili,Total_protein,Albumin,Troponin,CRP,Hb,Ht,RBC_count,WBC_count,Platelets_count,PTT,PT,ACT,INR,Arterial_pH,paO2,paCO2,Arterial_BE,Arterial_lactate,HCO3,ETCO2,SvO2,Anion_Gap,Ammonia,Fibrinogen,Absolute_Neutrophil_Count,Phosphorous,SaO2,Triglyceride,ScvO2,LDH,CK_MB,BNP,Iron,Thyroid_Stimulating_Hormone,Creatinine_Urine,Potassium_Urine,Sodium_Urine,Urea_Nitrogen_Urine,Creatinine_Clearance,T3,Gamma_Glutamyltransferase,Myoglobin,Heparin_LMW,Osmolality_Urine,Insulin,Shock_Index,PaO2_FiO2,SOFA,SIRS,bloc
0,30000153,0.00,0.00,0.0,0.0,0.00,0.000,0.000,0,,0,0,61,1,False,,70.0,0,,0,260.167,15.000000,2.200000,60.400000,119.000000,79.000000,59.000000,14.200000,99.400000,36.955500,98.060000,12.800000,,,,,,0.0,24.000000,0.240000,4.000000,5.000000,500.000000,6.660000,6.400000,11.000000,10.600000,22.000000,,,,,,,,,,4.000000,138.000000,101.000000,110.800000,41.000000,3.900000,2.30000,10.000000,1.100000,21.000000,319.000000,363.000000,13.400000,5.400000,,3.60,0.160000,,9.672960,28.900000,3.220000,18.000000,91.000000,28.900000,13.900000,,1.30000,7.360000,100.000000,38.000000,-4.000000,1.500000,21.000000,,,16.000000,,167.000000,,3.800000,,,,194.000000,10.000000,,,,,,,,,,,,,,,0.507563,416.667000,9,1,1
1,30000484,250.00,250.00,360.0,360.0,610.00,0.000,0.000,0,,0,0,92,4,True,163.0,68.5,0,0.0,1,180.000,14.500000,-0.500000,88.651550,111.950000,75.700000,57.575000,15.522750,97.000000,36.277800,97.300000,12.214300,,,,,,2.0,34.000000,0.340000,3.500000,7.500000,732.625000,9.900000,9.500000,13.000000,19.333350,16.866650,,,,,,,,,,4.000000,142.500000,102.500000,93.000000,15.000000,0.500000,2.05000,8.550000,1.100000,31.000000,30.500000,86.500000,1.000000,0.518200,,2.75,0.135000,56.5,10.750000,33.350000,3.135000,12.300000,63.000000,82.525650,14.800000,,1.30000,7.448000,77.900000,46.900000,-3.000000,1.110000,23.500000,,,11.000000,,349.000000,,2.750000,,,,522.500000,1.500000,,,,,,,,,,,,,,,0.813359,240.138900,6,2,1
2,30000484,1754.85,1504.85,680.0,320.0,2434.85,0.000,0.000,0,,0,0,92,4,True,163.0,68.5,0,0.0,1,180.000,15.000000,0.000000,86.027767,99.210333,60.130967,40.591267,14.277767,99.793667,35.688300,96.238900,7.916667,,,,,,2.0,36.000000,0.360000,4.000000,5.000000,495.000000,7.275000,8.666667,24.333333,25.041667,13.333333,,,,,,,,,,5.033333,137.333333,104.666667,125.333333,46.000000,1.400000,2.20000,7.633333,1.133333,33.666667,42.000000,147.888867,0.566667,0.217727,,2.50,0.205000,56.5,8.333333,30.566667,2.748890,25.100000,280.666667,35.866667,15.500000,,1.35000,7.480000,21.000000,59.000000,10.000000,1.844443,25.000000,,,12.333333,,438.500000,,2.366667,,,,411.333333,26.000000,,,,,,,,,,,,,,,0.871841,58.333300,7,2,2
3,30000484,2432.85,678.00,1300.0,620.0,3732.85,0.125,0.150,0,,0,0,92,4,True,163.0,68.5,0,0.0,1,180.000,15.000000,0.000000,92.033333,116.289000,67.250000,46.063900,15.327767,99.800000,36.390700,97.503333,6.666667,,,,,,2.0,36.000000,0.360000,4.000000,5.000000,412.166667,7.337500,8.000000,20.916667,20.416667,15.166667,,,,,,,,,,5.233333,136.000000,104.000000,94.000000,47.000000,1.200000,2.30000,7.800000,1.216667,29.000000,50.000000,32.333300,0.300000,0.032820,,2.50,0.230000,56.5,8.100000,24.600000,3.082223,24.200000,357.000000,36.100000,16.200000,,1.40000,7.456667,21.000000,59.000000,2.333333,1.533333,27.000000,,,10.000000,,676.666667,,1.900000,,,,419.000000,11.000000,,,,,,,,,,,,,,,0.794110,58.333300,6,2,3
4,30000484,2527.85,95.00,1910.0,610.0,4437.85,0.048,0.048,0,,0,0,92,4,True,163.0,68.5,0,0.0,1,180.000,14.714300,-0.285714,96.530567,126.625000,82.905567,61.045833,20.269433,100.000000,36.608300,97.895000,7.533333,,,,,,2.0,36.000000,0.360000,4.000000,6.000000,437.733333,8.647777,7.666667,12.622233,16.000000,19.000000,,,,,,,,,,4.811110,137.333333,104.666667,108.074000,43.666667,1.255557,2.30000,8.077777,1.113333,31.000000,38.333333,20.666667,0.355556,0.071342,5.2,2.70,0.213333,56.5,8.700000,25.822233,3.313333,21.533333,343.666667,39.155567,16.088900,,1.40000,7.463333,136.000000,37.000000,4.333333,1.766667,25.666667,,,10.000000,,842.000000,,2.011110,,,,308.444333,7.666667,,,,,,,,,,,,,,,0.762922,377.778000,2,3,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
418142,39999858,290.00,290.00,5250.0,1900.0,5540.00,0.000,0.000,0,,0,0,62,4,False,,123.8,0,,0,248.083,15.000000,0.000000,69.842867,115.781000,76.980967,57.580967,22.314267,95.095233,36.994733,98.590500,,,,,,,2.0,54.261900,0.542619,40.000000,8.098033,359.533333,7.354510,10.768633,19.000000,20.000000,20.000000,,,,,,,,,,3.700000,138.000000,101.000000,158.676333,19.000000,0.600000,1.90000,9.100000,1.200000,28.000000,41.000000,47.000000,0.500000,0.171500,,3.30,0.216667,,12.400000,37.600000,4.150000,7.100000,170.000000,28.700000,12.900000,,1.20000,7.426667,190.100000,50.000000,3.866667,2.233333,25.333333,42.333333,,12.000000,,,,4.000000,,,,373.000000,4.333333,,,,,,,,,,,,0.21,,,0.603307,356.835333,1,1,5
418143,39999858,290.00,0.00,5925.0,675.0,6215.00,0.000,0.000,0,,0,0,62,4,False,,123.8,0,,0,248.083,14.689467,-0.310541,61.084433,115.248667,75.798733,56.073800,23.205133,92.512833,36.805400,98.249667,,,,,,,2.0,39.102567,0.391026,32.424233,8.000000,328.600000,8.840000,12.600000,19.000000,17.676467,25.000000,,,,,,,,,,3.711110,137.333333,100.666667,203.889000,18.777767,0.600000,1.90000,9.077777,1.168890,28.000000,40.888900,46.777767,0.533333,0.194613,,3.30,0.630000,,12.266667,37.500000,4.135557,7.200000,171.666667,29.200000,13.000000,,1.21111,7.456667,163.866667,55.333333,2.400000,0.900000,25.333333,44.000000,,11.777767,,,,3.977777,,,,379.666667,3.000000,,,,,,,,,,,,0.21,,,0.533346,420.389000,1,1,6
418144,39999858,580.00,290.00,7525.0,1600.0,8105.00,0.000,0.000,0,,0,0,62,4,False,,123.8,0,,0,248.083,15.000000,0.000000,72.248667,118.121333,75.365100,53.986800,27.571433,92.904767,36.866833,98.360333,,,,,,,2.0,40.000000,0.400000,19.285720,7.000000,372.333333,6.641177,9.066667,16.254900,23.552933,18.000000,,,,,,,,,,3.800000,136.000000,100.000000,175.767000,17.000000,0.600000,1.90000,8.900000,1.161667,28.666667,40.000000,45.000000,0.800000,0.379520,,3.30,0.423333,,12.000000,36.700000,4.020000,8.000000,185.000000,33.200000,13.800000,,1.30000,7.470000,149.533333,57.000000,3.733333,1.100000,24.333333,38.000000,,10.000000,,,,3.800000,,,,433.000000,4.000000,,,,,,,,,,,,0.21,,,0.612765,373.833333,1,1,7
418145,39999858,580.00,0.00,7875.0,350.0,8455.00,0.000,0.000,0,,0,0,62,4,False,,123.8,0,,0,248.083,15.000000,0.000000,61.407933,113.454667,72.595233,52.165467,24.160333,93.206333,36.763533,98.174300,,,,,,,2.0,41.333333,0.413333,6.000000,4.000000,350.000000,6.400000,5.000000,14.000000,20.176467,18.000000,,,,,,,,,,3.777777,135.333333,99.000000,169.285667,17.222233,0.611111,1.88889,8.900000,1.146667,28.000000,39.444433,45.000000,0.800000,0.379520,,3.30,0.423333,,12.233333,36.888900,4.020000,8.055557,188.333333,32.600000,13.855567,,1.30000,7.433333,178.300000,66.000000,5.600000,0.900000,26.000000,39.666667,,10.111100,,,,3.744443,,,,428.666667,3.666667,,,,,,,,,,,,,,,0.541267,432.242333,1,1,8


In [18]:
patient_day_counts = MIMICtable.groupby('icustayid')['bloc'].nunique().reset_index()
patient_day_counts.rename(columns={'bloc': 'num_blocs'}, inplace=True)

patients_with_2_days = patient_day_counts[patient_day_counts['num_blocs'] >= 2]['icustayid']
# Step 5: Filter the aggregated data
MIMICtable = MIMICtable[MIMICtable['icustayid'].isin(patients_with_2_days)].reset_index(drop=True)


In [19]:
MIMICtable.to_parquet(os.path.join(in_dir, "MIMIC_action.parquet"))

In [20]:
import pandas as pd
import numpy as np

# Parameters
fixed_num_features = 40

# Load feature importance
feature_importance = pd.read_csv('/home/lkapral/RRT_mimic_iv/data/model/combined_feature_importances.csv')



# Sort by 'Combined_Average' descending to identify top features
feature_importance_sorted = feature_importance.sort_values(by='Combined_Average', ascending=False)

# Extract the top N features and compute normalized weights
top_features = feature_importance_sorted.head(fixed_num_features)
weights = top_features['Combined_Average'].values
normalized_weights = weights / np.linalg.norm(weights)

# Create dictionary of feature to weight
feature_to_weight = {}
for i, feat in enumerate(top_features['Feature']):
    feature_to_weight[feat] = normalized_weights[i]

# Assign weight 0 to all other features
all_features = feature_importance['Feature'].tolist()
for feat in all_features:
    if feat not in feature_to_weight:
        feature_to_weight[feat] = 0.0

# Compute statistics from MIMICtable for each feature
total_rows = MIMICtable.shape[0]
mean_list = []
std_list = []
missingness_list = []

for feat in feature_importance['Feature']:
    col_data = MIMICtable[feat]
    feat_mean = col_data.mean()
    feat_std = col_data.std()
    missing_count = col_data.isnull().sum()
    # Convert missing count to percentage
    missing_percent = (missing_count / total_rows) * 100

    mean_list.append(feat_mean)
    std_list.append(feat_std)
    missingness_list.append(missing_percent)

# Create the final dataframe
final_df = pd.DataFrame({
    'Feature': feature_importance['Feature'],
    'Mean': mean_list,
    'Std': std_list,
    'Missingness (%)': missingness_list,
    'Feature weight': [feature_to_weight[feat] for feat in feature_importance['Feature']]
})

# Combine Mean and Std into "Mean (SD)" column
final_df['Mean (SD)'] = final_df['Mean'].round(2).astype(str) + " ± " + final_df['Std'].round(2).astype(str)
final_df.drop(['Mean','Std'], axis=1, inplace=True)

# Round feature weight and missingness
final_df['Feature weight'] = final_df['Feature weight'].round(4)
final_df['Missingness (%)'] = final_df['Missingness (%)'].round(1)

# Create a mapping of old feature names to the updated names:
feature_name_mapping = {
    'output_step': '12-hour total output, mL',
    'SOFA': 'SOFA score',
    'cumulated_balance': 'Cumulative balance, mL',
    'Creatinine': 'Creatinine, mg/dL',
    'Platelets_count': 'Platelet count, ×10^3/µL',
    'Chloride': 'Chloride, mEq/L',
    'BUN': 'BUN, mg/dL',
    'Anion_Gap': 'Anion gap, mEq/L',
    'Calcium': 'Calcium, mg/dL',
    'input_total': 'Total input, mL',
    'WBC_count': 'WBC count, ×10^3/µL',
    'Total_bili': 'Total bilirubin, mg/dL',
    'Phosphorous': 'Phosphorus, mg/dL',
    'O2flow': 'O2 flow, L/min',
    'output_total': 'Total output, mL',
    'Weight_kg': 'Weight, kg',
    'RASS': 'RASS score',
    'Sodium': 'Sodium, mEq/L',
    'Temp_C': 'Temperature, °C',
    'age': 'Age, years',
    'max_dose_vaso': 'Maximum vasopressor dose, µg/kg/min',
    'PAWmean': 'Mean airway pressure, cmH2O',
    'GCS': 'GCS score',
    'SGOT': 'AST (SGOT), U/L',
    'PT': 'PT, s',
    'PTT': 'PTT, s',
    'RBC_count': 'RBC count, ×10^6/µL',
    'LDH': 'LDH, U/L',
    'Ht': 'Hematocrit, %',
    'RR': 'Respiratory rate, breaths/min',
    'HCO3': 'Bicarbonate, mEq/L',
    'SpO2': 'SpO2, %',
    'Ionised_Ca': 'Ionized calcium, mmol/L',
    'Hb': 'Hemoglobin, g/dL',
    'FiO2_1': 'FiO2, %',
    'SGPT': 'ALT (SGPT), U/L',
    'Shock_Index': 'Shock index',
    'Glucose': 'Glucose, mg/dL',
    'HR': 'Heart rate, beats/min',
    'MinuteVentil': 'Minute ventilation, L/min',
    'MeanBP': 'Mean blood pressure, mmHg',
    'INR': 'INR',
    'Potassium': 'Potassium, mEq/L',
    'Fibrinogen': 'Fibrinogen, mg/dL',
    'Arterial_pH': 'Arterial pH',
    'PaO2_FiO2': 'PaO2/FiO2 ratio',
    'TidalVolume': 'Tidal volume, mL',
    'paO2': 'PaO2, mmHg',
    'Albumin': 'Albumin, g/dL',
    'DiaBP': 'Diastolic blood pressure, mmHg',
    'input_step': '12-hour total input, mL',
    'Magnesium': 'Magnesium, mg/dL',
    'SysBP': 'Systolic blood pressure, mmHg',
    'PAWpeak': 'Peak airway pressure, cmH2O',
    'extubated': 'Extubated (yes/no)',
    'Arterial_BE': 'Arterial base excess, mEq/L',
    'PAWplateau': 'Plateau airway pressure, cmH2O',
    'Height_cm': 'Height, cm',
    'CVP': 'cCntral venous pressure, mmHg',
    'paCO2': 'PaCO2, mmHg',
    'Arterial_lactate': 'Arterial lactate, mmol/L',
    'PEEP': 'PEEP, cmH2O',
    'CK_MB': 'CK-MB, ng/mL',
    'ETCO2': 'End-tidal CO2, mmHg',
    'Troponin': 'Troponin, ng/mL',
    'mechvent': 'Mechanical ventilation (yes/no)',
    'Absolute_Neutrophil_Count': 'Absolute neutrophil count, ×10^3/µL',
    'SIRS': 'SIRS criteria',
    'SaO2': 'SaO2, %',
    'Triglyceride': 'Triglycerides, mg/dL',
    'SvO2': 'SvO2, %',
    'PAPsys': 'Pulmonary artery systolic pressure, mmHg',
    'PAPdia': 'Pulmonary artery diastolic pressure, mmHg',
    're_admission': 're-admission (yes/no)',
    'PAPmean': 'Mean pulmonary artery pressure, mmHg',
    'Creatinine_Urine': 'Urine creatinine, mg/dL',
    'gender': 'gender (M/F)',
    'BNP': 'BNP, pg/mL',
    'CRP': 'CRP, mg/L',
    'Urea_Nitrogen_Urine': 'Urine urea nitrogen, mg/dL',
    'Sodium_Urine': 'Urine sodium, mEq/L',
    'Potassium_Urine': 'Urine potassium, mEq/L',
    'Iron': 'Iron, µg/dL',
    'Ammonia': 'Ammonia, µg/dL',
    'Thyroid_Stimulating_Hormone': 'TSH, mIU/L',
    'Total_protein': 'Total protein, g/dL',
    'CI': 'Cardiac index, L/min/m²',
    'ACT': 'ACT, s',
    'T3': 'T3, ng/dL',
    'Gamma_Glutamyltransferase': 'GGT, U/L',
    'Heparin_LMW': 'Low molecular weight heparin (yes/no)',
    'APACHEII_Renal_Failure': 'APACHE II renal failure score',
    'Osmolality_Urine': 'Urine osmolality, mOsm/kg'
}

# Update the Feature column with the new names
final_df['Feature'] = final_df['Feature'].map(feature_name_mapping)

# Save final CSV
output_path = '/home/lkapral/RRT_mimic_iv/data/model/MIMICtable_features_with_weights.csv'
final_df.to_csv(output_path, index=False)


In [21]:
output_path = '/home/lkapral/RRT_mimic_iv/data/model/MIMICtable_features_with_weights.xlsx'
final_df.to_excel(output_path, index=False)

In [22]:


# find patients who died in ICU during data collection period
icuuniqueids = MIMICtable[C_ICUSTAYID].unique()
train_ids, test_ids = train_test_split(icuuniqueids, train_size=args.train_size, random_state=42)
train_indexes = MIMICtable[MIMICtable[C_ICUSTAYID].isin(train_ids)].index
test_indexes = MIMICtable[MIMICtable[C_ICUSTAYID].isin(test_ids)].index
print("Training: {} IDs ({} rows)".format(len(train_ids), len(train_indexes)))
print("Test: {} IDs ({} rows)".format(len(test_ids), len(test_indexes)))

MIMICraw = MIMICtable[ALL_FEATURE_COLUMNS]

print("Proportion of NA values:", MIMICraw.isna().sum() / len(MIMICraw))

normer = DataNormalization(MIMICtable.iloc[train_indexes])
MIMICzs_train = normer.transform(MIMICtable.iloc[train_indexes])
MIMICzs_test = normer.transform(MIMICtable.iloc[test_indexes])

train_dir = os.path.join(out_dir, "train")
test_dir = os.path.join(out_dir, "test")
if not os.path.exists(train_dir):
    os.mkdir(train_dir)
if not os.path.exists(test_dir):
    os.mkdir(test_dir)

metadata = MIMICtable[[C_BLOC, C_ICUSTAYID, args.outcome_col]].rename({args.outcome_col: C_OUTCOME}, axis=1)

# Save files
print("Saving files")
normer.save(os.path.join(out_dir, 'normalization.pkl'))
save_data_files(train_dir,
                MIMICraw.iloc[train_indexes],
                MIMICzs_train,
                metadata.iloc[train_indexes])
save_data_files(test_dir,
                MIMICraw.iloc[test_indexes],
                MIMICzs_test,
                metadata.iloc[test_indexes])    
print("Done.")

Training: 37992 IDs (293094 rows)
Test: 16283 IDs (124469 rows)
Proportion of NA values: gender                       0.000000
mechvent                     0.000000
extubated                    0.640078
max_dose_vaso                0.000000
re_admission                 0.000000
                               ...   
Gamma_Glutamyltransferase    0.998709
input_total                  0.000000
input_step                   0.000000
output_total                 0.000000
output_step                  0.000000
Length: 101, dtype: float64
Saving files
Done.


In [23]:
MIMICtable

Unnamed: 0,icustayid,input_total,input_step,output_total,output_step,cumulated_balance,median_dose_vaso,max_dose_vaso,mechvent,extubated,action,gender,age,elixhauser,re_admission,Height_cm,Weight_kg,died_in_hosp,died_within_48h_of_out_time,morta_90,delay_end_of_record_and_discharge_or_death,GCS,RASS,HR,SysBP,MeanBP,DiaBP,RR,SpO2,Temp_C,Temp_F,CVP,PAPsys,PAPmean,PAPdia,CI,SVR,Interface,FiO2_100,FiO2_1,O2flow,PEEP,TidalVolume,MinuteVentil,PAWmean,PAWpeak,PAWplateau,Respiratory_Rate,Ultrafiltrate_Output,Blood_Flow,Hourly_Patient_Fluid_Removal,Dialysate_Rate,APACHEII_Renal_Failure,Hemodialysis_Output,Citrate,Prefilter_Replacement_Rate,Postfilter_Replacement_Rate,Potassium,Sodium,Chloride,Glucose,BUN,Creatinine,Magnesium,Calcium,Ionised_Ca,CO2_mEqL,SGOT,SGPT,Total_bili,Direct_bili,Total_protein,Albumin,Troponin,CRP,Hb,Ht,RBC_count,WBC_count,Platelets_count,PTT,PT,ACT,INR,Arterial_pH,paO2,paCO2,Arterial_BE,Arterial_lactate,HCO3,ETCO2,SvO2,Anion_Gap,Ammonia,Fibrinogen,Absolute_Neutrophil_Count,Phosphorous,SaO2,Triglyceride,ScvO2,LDH,CK_MB,BNP,Iron,Thyroid_Stimulating_Hormone,Creatinine_Urine,Potassium_Urine,Sodium_Urine,Urea_Nitrogen_Urine,Creatinine_Clearance,T3,Gamma_Glutamyltransferase,Myoglobin,Heparin_LMW,Osmolality_Urine,Insulin,Shock_Index,PaO2_FiO2,SOFA,SIRS,bloc
0,30000484,250.00,250.00,360.0,360.0,610.00,0.000,0.000,0,,0,0,92,4,True,163.0,68.5,0,0.0,1,180.000,14.500000,-0.500000,88.651550,111.950000,75.700000,57.575000,15.522750,97.000000,36.277800,97.300000,12.214300,,,,,,2.0,34.000000,0.340000,3.500000,7.500000,732.625000,9.900000,9.500000,13.000000,19.333350,16.866650,,,,,,,,,,4.000000,142.500000,102.500000,93.000000,15.000000,0.500000,2.05000,8.550000,1.100000,31.000000,30.500000,86.500000,1.000000,0.518200,,2.75,0.135000,56.5,10.750000,33.350000,3.135000,12.300000,63.000000,82.525650,14.800000,,1.30000,7.448000,77.900000,46.900000,-3.000000,1.110000,23.500000,,,11.000000,,349.000000,,2.750000,,,,522.500000,1.500000,,,,,,,,,,,,,,,0.813359,240.138900,6,2,1
1,30000484,1754.85,1504.85,680.0,320.0,2434.85,0.000,0.000,0,,0,0,92,4,True,163.0,68.5,0,0.0,1,180.000,15.000000,0.000000,86.027767,99.210333,60.130967,40.591267,14.277767,99.793667,35.688300,96.238900,7.916667,,,,,,2.0,36.000000,0.360000,4.000000,5.000000,495.000000,7.275000,8.666667,24.333333,25.041667,13.333333,,,,,,,,,,5.033333,137.333333,104.666667,125.333333,46.000000,1.400000,2.20000,7.633333,1.133333,33.666667,42.000000,147.888867,0.566667,0.217727,,2.50,0.205000,56.5,8.333333,30.566667,2.748890,25.100000,280.666667,35.866667,15.500000,,1.35000,7.480000,21.000000,59.000000,10.000000,1.844443,25.000000,,,12.333333,,438.500000,,2.366667,,,,411.333333,26.000000,,,,,,,,,,,,,,,0.871841,58.333300,7,2,2
2,30000484,2432.85,678.00,1300.0,620.0,3732.85,0.125,0.150,0,,0,0,92,4,True,163.0,68.5,0,0.0,1,180.000,15.000000,0.000000,92.033333,116.289000,67.250000,46.063900,15.327767,99.800000,36.390700,97.503333,6.666667,,,,,,2.0,36.000000,0.360000,4.000000,5.000000,412.166667,7.337500,8.000000,20.916667,20.416667,15.166667,,,,,,,,,,5.233333,136.000000,104.000000,94.000000,47.000000,1.200000,2.30000,7.800000,1.216667,29.000000,50.000000,32.333300,0.300000,0.032820,,2.50,0.230000,56.5,8.100000,24.600000,3.082223,24.200000,357.000000,36.100000,16.200000,,1.40000,7.456667,21.000000,59.000000,2.333333,1.533333,27.000000,,,10.000000,,676.666667,,1.900000,,,,419.000000,11.000000,,,,,,,,,,,,,,,0.794110,58.333300,6,2,3
3,30000484,2527.85,95.00,1910.0,610.0,4437.85,0.048,0.048,0,,0,0,92,4,True,163.0,68.5,0,0.0,1,180.000,14.714300,-0.285714,96.530567,126.625000,82.905567,61.045833,20.269433,100.000000,36.608300,97.895000,7.533333,,,,,,2.0,36.000000,0.360000,4.000000,6.000000,437.733333,8.647777,7.666667,12.622233,16.000000,19.000000,,,,,,,,,,4.811110,137.333333,104.666667,108.074000,43.666667,1.255557,2.30000,8.077777,1.113333,31.000000,38.333333,20.666667,0.355556,0.071342,5.2,2.70,0.213333,56.5,8.700000,25.822233,3.313333,21.533333,343.666667,39.155567,16.088900,,1.40000,7.463333,136.000000,37.000000,4.333333,1.766667,25.666667,,,10.000000,,842.000000,,2.011110,,,,308.444333,7.666667,,,,,,,,,,,,,,,0.762922,377.778000,2,3,4
4,30000484,2527.85,0.00,2320.0,410.0,4847.85,0.000,0.000,0,,0,0,92,4,True,163.0,68.5,0,0.0,1,180.000,14.714300,-0.285714,90.483333,117.816667,77.816667,54.416667,19.733333,99.800000,36.270333,97.286667,7.000000,,,,,,2.0,30.933333,0.309333,2.733333,3.333333,386.666667,7.066667,7.000000,15.333333,15.000000,19.000000,,,,,,,,,,4.500000,137.000000,104.333333,121.666667,41.000000,1.300000,2.30000,8.300000,1.183333,24.333333,29.000000,12.000000,0.400000,0.102160,5.2,2.80,0.200000,,9.037970,26.800000,2.960000,19.400000,333.000000,41.600000,16.000000,,1.40000,7.470000,81.000000,43.000000,-2.666667,1.366667,27.666667,,,10.000000,,842.000000,,2.100000,,,,220.000000,6.666667,,,,,,,,,,,,,,,0.769480,251.433800,3,2,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
417558,39999858,290.00,290.00,5250.0,1900.0,5540.00,0.000,0.000,0,,0,0,62,4,False,,123.8,0,,0,248.083,15.000000,0.000000,69.842867,115.781000,76.980967,57.580967,22.314267,95.095233,36.994733,98.590500,,,,,,,2.0,54.261900,0.542619,40.000000,8.098033,359.533333,7.354510,10.768633,19.000000,20.000000,20.000000,,,,,,,,,,3.700000,138.000000,101.000000,158.676333,19.000000,0.600000,1.90000,9.100000,1.200000,28.000000,41.000000,47.000000,0.500000,0.171500,,3.30,0.216667,,12.400000,37.600000,4.150000,7.100000,170.000000,28.700000,12.900000,,1.20000,7.426667,190.100000,50.000000,3.866667,2.233333,25.333333,42.333333,,12.000000,,,,4.000000,,,,373.000000,4.333333,,,,,,,,,,,,0.21,,,0.603307,356.835333,1,1,5
417559,39999858,290.00,0.00,5925.0,675.0,6215.00,0.000,0.000,0,,0,0,62,4,False,,123.8,0,,0,248.083,14.689467,-0.310541,61.084433,115.248667,75.798733,56.073800,23.205133,92.512833,36.805400,98.249667,,,,,,,2.0,39.102567,0.391026,32.424233,8.000000,328.600000,8.840000,12.600000,19.000000,17.676467,25.000000,,,,,,,,,,3.711110,137.333333,100.666667,203.889000,18.777767,0.600000,1.90000,9.077777,1.168890,28.000000,40.888900,46.777767,0.533333,0.194613,,3.30,0.630000,,12.266667,37.500000,4.135557,7.200000,171.666667,29.200000,13.000000,,1.21111,7.456667,163.866667,55.333333,2.400000,0.900000,25.333333,44.000000,,11.777767,,,,3.977777,,,,379.666667,3.000000,,,,,,,,,,,,0.21,,,0.533346,420.389000,1,1,6
417560,39999858,580.00,290.00,7525.0,1600.0,8105.00,0.000,0.000,0,,0,0,62,4,False,,123.8,0,,0,248.083,15.000000,0.000000,72.248667,118.121333,75.365100,53.986800,27.571433,92.904767,36.866833,98.360333,,,,,,,2.0,40.000000,0.400000,19.285720,7.000000,372.333333,6.641177,9.066667,16.254900,23.552933,18.000000,,,,,,,,,,3.800000,136.000000,100.000000,175.767000,17.000000,0.600000,1.90000,8.900000,1.161667,28.666667,40.000000,45.000000,0.800000,0.379520,,3.30,0.423333,,12.000000,36.700000,4.020000,8.000000,185.000000,33.200000,13.800000,,1.30000,7.470000,149.533333,57.000000,3.733333,1.100000,24.333333,38.000000,,10.000000,,,,3.800000,,,,433.000000,4.000000,,,,,,,,,,,,0.21,,,0.612765,373.833333,1,1,7
417561,39999858,580.00,0.00,7875.0,350.0,8455.00,0.000,0.000,0,,0,0,62,4,False,,123.8,0,,0,248.083,15.000000,0.000000,61.407933,113.454667,72.595233,52.165467,24.160333,93.206333,36.763533,98.174300,,,,,,,2.0,41.333333,0.413333,6.000000,4.000000,350.000000,6.400000,5.000000,14.000000,20.176467,18.000000,,,,,,,,,,3.777777,135.333333,99.000000,169.285667,17.222233,0.611111,1.88889,8.900000,1.146667,28.000000,39.444433,45.000000,0.800000,0.379520,,3.30,0.423333,,12.233333,36.888900,4.020000,8.055557,188.333333,32.600000,13.855567,,1.30000,7.433333,178.300000,66.000000,5.600000,0.900000,26.000000,39.666667,,10.111100,,,,3.744443,,,,428.666667,3.666667,,,,,,,,,,,,,,,0.541267,432.242333,1,1,8


In [24]:
MIMICtable.groupby('icustayid')['bloc'].max().min()

2

0.16473514509442652