In [1]:
import psycopg2
from datetime import timedelta
from sqlalchemy import create_engine
import psycopg2
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import cross_val_score
from sklearn.metrics import precision_score, recall_score,f1_score, accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from sktime.transformations.panel.rocket import Rocket
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import VotingClassifier
from sklearn.pipeline import Pipeline

from imblearn.ensemble import BalancedRandomForestClassifier
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler

In [2]:
import psycopg2

conn = psycopg2.connect(
    host="localhost",
    database="mimic",
    user="postgres",
    password="postgres"
)

cur = conn.cursor()

cur.execute("SELECT version();")
print(cur.fetchone())

('PostgreSQL 15.2, compiled by Visual C++ build 1914, 64-bit',)


In [3]:
# Connect to db
conn = psycopg2.connect(host='localhost', dbname='mimic', user='postgres', password='postgres', options='-c search_path=mimiciii')
cur = conn.cursor() 

# Read in table with patients & admissions (inner join on subject_id) and icu_stays (inner joinon subject_id and hadm_id)
icustay_details = pd.read_sql_query("SELECT * FROM mimiciii.flicu_icustay_detail;", conn)

# Read in vital and lab signs
pivoted_vital = pd.read_sql_query("SELECT * FROM mimiciii.pivoted_vital;", conn)
pivoted_lab = pd.read_sql_query("SELECT * FROM mimiciii.ckd_pivoted_lab;", conn)

# Close the cursor and connection to so the server can allocate bandwidth to other requests
cur.close()
conn.close()

  icustay_details = pd.read_sql_query("SELECT * FROM mimiciii.flicu_icustay_detail;", conn)
  pivoted_vital = pd.read_sql_query("SELECT * FROM mimiciii.pivoted_vital;", conn)
  pivoted_lab = pd.read_sql_query("SELECT * FROM mimiciii.ckd_pivoted_lab;", conn)


In [4]:
WINDOW_LENGTH = 96

In [5]:
data= icustay_details.copy()
data = data[data.los_icu >= WINDOW_LENGTH/24.0]

In [6]:
filtered_icustay_ids = pd.DataFrame(data['icustay_id'].unique(), columns=['icustay_id'])

In [7]:
# Drop measurements with no belonging icustay_id
pivoted_vital = pivoted_vital.dropna(subset=['icustay_id'])
pivoted_lab = pivoted_lab.dropna(subset=['icustay_id'])

# Cast icustay_id types to int
pivoted_vital['icustay_id'] = pivoted_vital['icustay_id'].astype(int)
pivoted_lab['icustay_id'] = pivoted_lab['icustay_id'].astype(int)

# Keep only values of patients in previously filtered icustay_ids in labs and vitals
pivoted_vital = pivoted_vital.merge(filtered_icustay_ids, on='icustay_id', how='right').drop_duplicates()
pivoted_lab = pivoted_lab.merge(filtered_icustay_ids, on='icustay_id', how='right').drop_duplicates()

In [8]:
# Min of each lab and vitals
icustay_ids_charttime_min_lab = pivoted_lab[["icustay_id", "charttime"]][pivoted_lab.groupby("icustay_id")["charttime"].rank(ascending=1,method='dense') == 1]
icustay_ids_charttime_min_vital = pivoted_vital[["icustay_id", "charttime"]][pivoted_vital.groupby("icustay_id")["charttime"].rank(ascending=1,method='dense') == 1]
# Min of both combined
icustay_ids_charttime_min_vital_lab = pd.concat([icustay_ids_charttime_min_lab, icustay_ids_charttime_min_vital], ignore_index=True)
icustay_ids_charttime_min_vital_lab = icustay_ids_charttime_min_vital_lab[["icustay_id", "charttime"]][icustay_ids_charttime_min_vital_lab.groupby("icustay_id")["charttime"].rank(ascending=1,method='dense') == 1]

# Max of each lab and vitals
icustay_ids_charttime_max_lab = pivoted_lab[["icustay_id", "charttime"]][pivoted_lab.groupby("icustay_id")["charttime"].rank(ascending=0,method='dense') == 1]
icustay_ids_charttime_max_vital = pivoted_vital[["icustay_id", "charttime"]][pivoted_vital.groupby("icustay_id")["charttime"].rank(ascending=0,method='dense') == 1]
# Max of both combined
icustay_ids_charttime_max_vital_lab = pd.concat([icustay_ids_charttime_max_lab, icustay_ids_charttime_max_vital], ignore_index=True)
icustay_ids_charttime_max_vital_lab = icustay_ids_charttime_max_vital_lab[["icustay_id", "charttime"]][icustay_ids_charttime_max_vital_lab.groupby("icustay_id")["charttime"].rank(ascending=0,method='dense') == 1]

In [9]:
# Find for which icustay_ids there exist at least WINDOW_LENGTH of data
icustay_ids_vital_lab_charttime_min_max = pd.concat([icustay_ids_charttime_max_vital_lab, icustay_ids_charttime_min_vital_lab], ignore_index=True)
time_window = timedelta(days=4, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=WINDOW_LENGTH, weeks=0)
is_time_diff_bigger_window_lab = icustay_ids_vital_lab_charttime_min_max.groupby(['icustay_id'])['charttime'].transform(lambda x: (x.max()-x.min())) >= time_window

icustay_ids_vital_lab_charttime_min_max_filtered = icustay_ids_vital_lab_charttime_min_max[is_time_diff_bigger_window_lab]
print("Unique icu stays in icustay_ids_vital_lab_charttime_min_max_filtered after filtering", icustay_ids_vital_lab_charttime_min_max_filtered['icustay_id'].nunique())

# Keep only icustay ids for which at least WINDOW_LENGTH of data exists
icustay_ids_time_filtered = pd.DataFrame(icustay_ids_vital_lab_charttime_min_max_filtered['icustay_id'].unique(), columns=['icustay_id'])
print("Unique icu stays in icustay_ids_time_filtered: ", icustay_ids_time_filtered['icustay_id'].nunique())

Unique icu stays in icustay_ids_vital_lab_charttime_min_max_filtered after filtering 8409
Unique icu stays in icustay_ids_time_filtered:  8409


In [10]:
filtered_icustay_ids = filtered_icustay_ids.merge(icustay_ids_time_filtered, on='icustay_id', how='inner').drop_duplicates()

In [11]:
demographics_filtered = data.merge(filtered_icustay_ids, on='icustay_id', how='right').drop_duplicates()
print("Number of ICU stays demographics: ", demographics_filtered['icustay_id'].nunique())

vital_filtered = pivoted_vital.merge(filtered_icustay_ids, on='icustay_id', how='right').drop_duplicates()
print("Number of ICU stays vitals: ", vital_filtered['icustay_id'].nunique())

lab_filtered = pivoted_lab.merge(filtered_icustay_ids, on='icustay_id', how='right').drop_duplicates()
print("Number of ICU stays labs: ", lab_filtered['icustay_id'].nunique())

Number of ICU stays demographics:  8409
Number of ICU stays vitals:  8409
Number of ICU stays labs:  8409


In [12]:
vital_filtered = vital_filtered.merge(lab_filtered[['icustay_id', 'charttime']], on=['icustay_id', 'charttime'], how='outer').drop_duplicates()
print("Number of ICU stays in lab_filtered: ", vital_filtered['icustay_id'].nunique())
lab_filtered = lab_filtered.merge(vital_filtered[['icustay_id', 'charttime']], on=['icustay_id', 'charttime'], how='outer').drop_duplicates()
print("Number of ICU stays in lab_filtered: ", lab_filtered['icustay_id'].nunique())

Number of ICU stays in lab_filtered:  8409
Number of ICU stays in lab_filtered:  8409


In [13]:
vital_resampled = vital_filtered.copy()

# Resample from the end of the time series (how="last")
vital_resampled = vital_resampled.assign(charttime=vital_resampled.charttime.dt.round('H'))

# Resample from the beginning of the time series
vital_resampled = vital_resampled.set_index('charttime').groupby('icustay_id').resample('1H', origin="start").median().drop(['icustay_id'], axis = 1).reset_index()

# Forward and backwards fill (use lambda function instead of directly applying it to groupby otherwise results from one group are carreid forward to another group...BAD)
# Fill NaNs (-1)
vital_col = vital_resampled.columns.drop(['icustay_id', 'charttime'])
vital_resampled = vital_resampled.set_index(['icustay_id', 'charttime']).groupby('icustay_id')[vital_col].transform(lambda x: x.ffill().bfill()).fillna(value=vital_resampled[['icustay_id', 'charttime', 'heartrate', 'sysbp', 'diasbp', 'meanbp','resprate', 'tempc', 'spo2', 'glucose', 'rbc', 'specificgravity','pedaledema', 'appetite_median']].median()).reset_index()


  vital_resampled = vital_resampled.set_index(['icustay_id', 'charttime']).groupby('icustay_id')[vital_col].transform(lambda x: x.ffill().bfill()).fillna(value=vital_resampled[['icustay_id', 'charttime', 'heartrate', 'sysbp', 'diasbp', 'meanbp','resprate', 'tempc', 'spo2', 'glucose', 'rbc', 'specificgravity','pedaledema', 'appetite_median']].median()).reset_index()


In [14]:
lab_resampled = lab_filtered.copy()
# Cut out minutes and hours, so that the resampling of the 8h takes the same time span as the 1h samples (for vitals)
lab_resampled = lab_resampled.assign(charttime=lab_resampled.charttime.dt.round('H'))
# Resample from the end of the time series 
#lab_resampled = lab_resampled.set_index('charttime').groupby('icustay_id').resample('8h', origin="end").median().drop(['icustay_id'], axis = 1).reset_index()
lab_resampled = lab_resampled.set_index('charttime').groupby('icustay_id').resample('8h', origin="start").median().drop(['icustay_id'], axis = 1).reset_index()

# Forward and backwards fill (use transform instead of direct groupby otherwise results from one group are carreid forward to another group...BAD)
# Fill NaNs (-1 or 0 or mean!?)
lab_col = lab_resampled.columns.drop(['icustay_id', 'charttime'])
lab_resampled = lab_resampled.set_index(['icustay_id', 'charttime']).groupby('icustay_id')[lab_col].transform(lambda x: x.ffill().bfill()).fillna(value=lab_resampled[['icustay_id', 'subject_id', 'charttime', 'aniongap', 'albumin', 'bands','bicarbonate', 'bilirubin', 'creatinine', 'chloride', 'glucose','hematocrit', 'hemoglobin', 'lactate', 'platelet', 'potassium', 'ptt','inr', 'pt', 'sodium', 'bun', 'wbc', 'bacteria']].median()).reset_index()

print(lab_resampled.isnull().sum().sum())

  lab_resampled = lab_resampled.set_index(['icustay_id', 'charttime']).groupby('icustay_id')[lab_col].transform(lambda x: x.ffill().bfill()).fillna(value=lab_resampled[['icustay_id', 'subject_id', 'charttime', 'aniongap', 'albumin', 'bands','bicarbonate', 'bilirubin', 'creatinine', 'chloride', 'glucose','hematocrit', 'hemoglobin', 'lactate', 'platelet', 'potassium', 'ptt','inr', 'pt', 'sodium', 'bun', 'wbc', 'bacteria']].median()).reset_index()


730


In [15]:
delta_t_data = timedelta(days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=WINDOW_LENGTH, weeks=0)
demographics_windowed = demographics_filtered.copy()
demographics_windowed['predtime'] = demographics_windowed.intime + delta_t_data
demographics_windowed['delta_t_pred'] = demographics_windowed.outtime - demographics_windowed.predtime

demographics_windowed[['subject_id', 'icustay_id', 'intime', 'predtime', 'delta_t_pred']].head(5)

Unnamed: 0,subject_id,icustay_id,intime,predtime,delta_t_pred
0,334,214236,2136-01-16 10:56:48,2136-01-20 10:56:48,10 days 07:21:18
1,2005,285731,2163-06-23 11:28:06,2163-06-27 11:28:06,5 days 08:45:56
2,12174,284866,2118-10-30 16:48:57,2118-11-03 16:48:57,13 days 00:44:12
3,13535,205010,2196-10-10 22:03:14,2196-10-14 22:03:14,88 days 19:52:36
4,21824,241223,2107-07-07 20:58:00,2107-07-11 20:58:00,31 days 15:33:00


In [16]:
cut_icustay_ids = pd.DataFrame(demographics_windowed['icustay_id'].unique(), columns=['icustay_id'])
print("Number of ICU stays: ", cut_icustay_ids['icustay_id'].count())

vitals_cut = vital_resampled.merge(cut_icustay_ids, on='icustay_id', how='right')
print("Number of ICU stays in vitals_cut: ", vitals_cut['icustay_id'].nunique())

labs_cut = lab_resampled.merge(cut_icustay_ids, on='icustay_id', how='right')
print("Number of ICU stays in labs_cut: ", labs_cut['icustay_id'].nunique())


Number of ICU stays:  8409
Number of ICU stays in vitals_cut:  8409
Number of ICU stays in labs_cut:  8409


In [17]:
vitals_windowed = vital_resampled.merge(demographics_windowed[['icustay_id', 'predtime', 'delta_t_pred']], on='icustay_id', how='right')
vitals_windowed = vitals_windowed[vitals_windowed.charttime < vitals_windowed.predtime]
print("Number of ICU stays in vitals_windowed: ", vitals_windowed['icustay_id'].nunique())

labs_windowed = lab_resampled.merge(demographics_windowed[['icustay_id', 'predtime', 'delta_t_pred']], on='icustay_id', how='right')
labs_windowed = labs_windowed[labs_windowed.charttime < labs_windowed.predtime]
print("Number of ICU stays in labs_windowed: ", labs_windowed['icustay_id'].nunique())

windowed_icustay_ids = pd.DataFrame(pd.concat([vitals_windowed['icustay_id'], labs_windowed['icustay_id']]).unique(), columns=['icustay_id'])
demographics_windowed = demographics_windowed.merge(windowed_icustay_ids, on='icustay_id', how='right')

Number of ICU stays in vitals_windowed:  8405
Number of ICU stays in labs_windowed:  8405


In [18]:
vitals_windowed['ckd'] = vitals_windowed['icustay_id'].map(demographics_windowed.set_index('icustay_id')['ckd'])

In [19]:
vitals_windowed =vitals_windowed.set_index(['icustay_id', 'charttime']).groupby('icustay_id')[vital_col].transform(lambda x: x.ffill().bfill()).fillna(-1).reset_index()

In [20]:
labs_windowed['ckd'] = labs_windowed['icustay_id'].map(demographics_windowed.set_index('icustay_id')['ckd'])

In [21]:
print("Number of ICU stays demographics: ", demographics_windowed['icustay_id'].nunique())
print("Number of CKD demographics:")
dd = demographics_windowed[['icustay_id','ckd']].drop_duplicates(subset=['icustay_id'])
print(dd['ckd'].value_counts())

print("Number of ICU stays vitals: ", vitals_windowed['icustay_id'].nunique())
print("Number of CKD vitals:")
dd = vitals_windowed[['icustay_id','ckd']].drop_duplicates(subset=['icustay_id'])
print(dd['ckd'].value_counts())

print("Number of ICU stays labs: ", labs_windowed['icustay_id'].nunique())
print("Number of CKD labs:")
dd = labs_windowed[['icustay_id','ckd']].drop_duplicates(subset=['icustay_id'])
print(dd['ckd'].value_counts())

Number of ICU stays demographics:  8405
Number of CKD demographics:
0    7868
1     537
Name: ckd, dtype: int64
Number of ICU stays vitals:  8405
Number of CKD vitals:
0    7868
1     537
Name: ckd, dtype: int64
Number of ICU stays labs:  8405
Number of CKD labs:
0    7868
1     537
Name: ckd, dtype: int64


In [22]:
def aggregate_dataframe(df, groupby_key, columns_to_aggregate):
    df = df.replace(-1, np.nan)
    result = df.groupby(groupby_key)[columns_to_aggregate].mean().reset_index()    
    return result

In [23]:
columns_to_merge = ['icustay_id', 'ckd','ethnicity_grouped']
df_cols_vitals = ['heartrate', 'sysbp','diasbp','meanbp','resprate','tempc','spo2','specificgravity','pedaledema','appetite_median']
df_agg_vitals = aggregate_dataframe(vitals_windowed, 'icustay_id', df_cols_vitals)
df_agg_vitals = df_agg_vitals.merge(demographics_windowed[columns_to_merge], on='icustay_id', how='inner')
df_agg_vitals['ckd_ethnicity'] = df_agg_vitals['ckd'].astype(str).str.cat(df_agg_vitals['ethnicity_grouped'].astype(str))

df_cols_labs = ['albumin','bacteria','glucose','bun','creatinine','sodium','potassium','hemoglobin','wbc','hematocrit','platelet','ptt']
df_agg_labs = aggregate_dataframe(labs_windowed, 'icustay_id', df_cols_labs)
df_agg_labs = df_agg_labs.merge(demographics_windowed[columns_to_merge], on='icustay_id', how='inner')
df_agg_labs['ckd_ethnicity'] = df_agg_labs['ckd'].astype(str).str.cat(df_agg_labs['ethnicity_grouped'].astype(str))

print("Vitals unique icustay id: ",len(df_agg_vitals['icustay_id'].unique()),"\nLabs unique icustay id: ",len(df_agg_labs['icustay_id'].unique()),"\nDemographics unique icustay id: ",len(demographics_windowed['icustay_id'].unique()))

Vitals unique icustay id:  8405 
Labs unique icustay id:  8405 
Demographics unique icustay id:  8405


In [24]:
df_agg_vitals_new=df_agg_vitals.drop(['ckd','ethnicity_grouped','ckd_ethnicity','pedaledema'],axis=1)
df_agg_labs_new=df_agg_labs.drop(['ckd','ethnicity_grouped','ckd_ethnicity'],axis=1)

In [25]:
merged_table_org = df_agg_labs_new.merge(df_agg_vitals_new, on='icustay_id', how='inner').merge(demographics_windowed, on='icustay_id', how='inner')

In [26]:
merged_table =merged_table_org.copy()

Table names : 
- demographics_windowed
- labs_windowed
- vitals_windowed
- df_agg_vitals
- df_agg_labs
- merged_table_org

In [27]:
# Calculate the difference between max and min charttime for labs_windowed
labs_diff = labs_windowed.groupby('icustay_id')['charttime'].apply(lambda x: x.max() - x.min())

# Calculate the difference between max and min charttime for vitals_windowed
vitals_diff = vitals_windowed.groupby('icustay_id')['charttime'].apply(lambda x: x.max() - x.min())

# Filter the icustay_id where the difference is grater than or equal to Window_length in both labs and vitals
filtered_icustay_ids = labs_diff[(labs_diff == pd.Timedelta(hours=WINDOW_LENGTH)) & (vitals_diff == pd.Timedelta(hours=WINDOW_LENGTH))].index.tolist()

# Print the length of icustay_id
print(len(filtered_icustay_ids))

3038


In [28]:
merged_table_filtered= merged_table[merged_table['icustay_id'].isin(filtered_icustay_ids)]

In [29]:
merged_table_filtered.columns

Index(['icustay_id', 'albumin', 'bacteria', 'glucose', 'bun', 'creatinine',
       'sodium', 'potassium', 'hemoglobin', 'wbc', 'hematocrit', 'platelet',
       'ptt', 'heartrate', 'sysbp', 'diasbp', 'meanbp', 'resprate', 'tempc',
       'spo2', 'specificgravity', 'appetite_median', 'subject_id', 'hadm_id',
       'gender', 'dod', 'admittime', 'dischtime', 'los_hospital',
       'admission_age', 'ethnicity', 'ethnicity_grouped',
       'hospital_expire_flag', 'hospstay_seq', 'first_hosp_stay', 'intime',
       'outtime', 'los_icu', 'icustay_seq', 'first_icu_stay_current_hosp',
       'first_icu_stay_patient', 'first_careunit', 'deathtime_icu',
       'label_death_icu', 'label_cor_art', 'diabetes_mellitus', 'ckd',
       'anemia_flag', 'predtime', 'delta_t_pred'],
      dtype='object')

In [30]:
merged_table_filtered=merged_table_filtered.drop(['subject_id','hadm_id','dod','admittime', 'dischtime','los_hospital','ethnicity','hospital_expire_flag','hospstay_seq', 'first_hosp_stay', 'intime','outtime', 'los_icu', 'icustay_seq', 'first_icu_stay_current_hosp','first_icu_stay_patient', 'first_careunit', 'deathtime_icu','label_death_icu', 'predtime', 'delta_t_pred'],axis=1)

In [31]:
age_ranges = [0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 400]
age_labels = ['0-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', '70-79', '80-89', '90+']
merged_table_filtered['age_group'] = pd.cut(merged_table_filtered['admission_age'], bins=age_ranges, labels=age_labels, right=False)
merged_table_filtered=merged_table_filtered.drop('admission_age',axis=1)

In [32]:
def evaluationCV(classifier,X, y):    
    cv_scores_pr = cross_val_score(classifier, X, y, cv=5, scoring='precision')    
    cv_scores_rc = cross_val_score(classifier, X, y, cv=5, scoring='recall')    
    cv_scores_f1 = cross_val_score(classifier,X, y, cv=5, scoring='f1')
    cv_scores_ac = cross_val_score(classifier, X, y, cv=5, scoring='accuracy') 
    
    print("Cross-validation scores Precision    :", cv_scores_pr)    
    print("Cross-validation scores Recall       :", cv_scores_rc)
    print("Cross-validation scores F1           :", cv_scores_f1)
    print("Cross-validation scores Accuracy     :", cv_scores_ac)
    
    print("Mean cross-validation score Precision:", np.mean(cv_scores_pr))
    print("Mean cross-validation score Recall   :", np.mean(cv_scores_rc))
    print("Mean cross-validation score F1       :", np.mean(cv_scores_f1))
    print("Mean cross-validation score Accuracy :", np.mean(cv_scores_ac))

In [33]:
def evaluationTest(classifier,X, y):  
    y_pred = classifier.predict(X)
    precision = precision_score(y, y_pred)
    recall = recall_score(y, y_pred)
    f1 = f1_score(y, y_pred)
    accuracy = accuracy_score(y, y_pred)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)
    print("Accuracy:", accuracy)
    return pd.DataFrame(y_pred)    

In [34]:
def metricsReport(y,y_pred):
    precision = precision_score(y, y_pred)
    recall = recall_score(y, y_pred)
    f1 = f1_score(y, y_pred)
    accuracy = accuracy_score(y, y_pred)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)
    print("Accuracy:", accuracy)
    return precision, recall , f1, accuracy

#### Data & Class separation

In [35]:
X = merged_table_filtered.drop(['ckd','icustay_id'],axis=1)
y = merged_table_filtered['ckd']

#### Trial 1:  Random forest for static + aggregated timeseries

In [36]:
X_onehot = pd.get_dummies(X)
X_train, X_test, y_train, y_test = train_test_split(X_onehot, y, test_size=0.20, stratify=y, random_state=42)

In [37]:
y_train.value_counts()

0    2305
1     125
Name: ckd, dtype: int64

In [38]:
y_test.value_counts()

0    577
1     31
Name: ckd, dtype: int64

In [39]:
from scipy.stats import randint
from sklearn.model_selection import RandomizedSearchCV

In [40]:
param_grid_rcv = {
    'n_estimators': randint(50, 500),
    'max_features': ['auto', 'sqrt', 'log2'],
    'max_depth' : randint(1, 10),
    'criterion' :['gini', 'entropy']
}

In [41]:
param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4, 6],
    'max_features': ['sqrt', 'log2','auto']
}

In [42]:
#runder = RandomUnderSampler(random_state=42)
#X_resampled, y_resampled = runder.fit_resample(X_train, y_train)
rf_merged = RandomForestClassifier(random_state=42)

#grid_search = RandomizedSearchCV(estimator=rf_merged, param_distributions=param_grid_rcv, n_iter=100, cv=5, random_state=42)

#grid_search = GridSearchCV(rf_merged, param_grid, cv=5)
#grid_search.fit(X_train, y_train)
#grid_search.best_params_

In [43]:
evaluationCV(rf_merged,X_train, y_train)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Cross-validation scores Precision    : [0. 0. 0. 0. 0.]
Cross-validation scores Recall       : [0. 0. 0. 0. 0.]
Cross-validation scores F1           : [0. 0. 0. 0. 0.]
Cross-validation scores Accuracy     : [0.94855967 0.94444444 0.94855967 0.94855967 0.94650206]
Mean cross-validation score Precision: 0.0
Mean cross-validation score Recall   : 0.0
Mean cross-validation score F1       : 0.0
Mean cross-validation score Accuracy : 0.9473251028806585


In [44]:
rf_merged_2 = RandomForestClassifier(random_state=42,
                                     max_depth = None, 
                                     max_features = 'sqrt',
                                     min_samples_leaf = 1,
                                     min_samples_split = 1,
                                     n_estimators= 100)

rf_merged_2.fit(X_train, y_train)

In [45]:
evaluationCV(rf_merged_2,X_train, y_train)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Cross-validation scores Precision    : [0. 0. 0. 0. 0.]
Cross-validation scores Recall       : [0. 0. 0. 0. 0.]
Cross-validation scores F1           : [0. 0. 0. 0. 0.]
Cross-validation scores Accuracy     : [0.94855967 0.94444444 0.94855967 0.94855967 0.94650206]
Mean cross-validation score Precision: 0.0
Mean cross-validation score Recall   : 0.0
Mean cross-validation score F1       : 0.0
Mean cross-validation score Accuracy : 0.9473251028806585


In [46]:
y_pred_merged_2 = evaluationTest(rf_merged_2,X_test, y_test)

Precision: 0.5
Recall: 0.03225806451612903
F1 Score: 0.06060606060606061
Accuracy: 0.9490131578947368


In [47]:
y_pred_merged_2.value_counts()

0    606
1      2
dtype: int64

#### Trial 2 - RandomUnderSampler

In [48]:
X_top =X[['creatinine', 'specificgravity', 'heartrate', 'bun', 'spo2', 'tempc', 'platelet', 'diasbp', 'bacteria', 'meanbp']]
X_onehot = pd.get_dummies(X_top)
X_train, X_test, y_train, y_test = train_test_split(X_onehot, y, test_size=0.20, stratify=y, random_state=42)

undersampler = RandomUnderSampler()
#undersampler = RandomUnderSampler(sampling_strategy={0: 250})
X_resampled, y_resampled = undersampler.fit_resample(X_train, y_train)

rf_ros = RandomForestClassifier(random_state=42,
                                max_depth = None, 
                                max_features = 'sqrt',
                                min_samples_leaf = 1,
                                min_samples_split = 1,
                                n_estimators= 100)

rf_ros.fit(X_resampled, y_resampled)
evaluationCV(rf_ros,X_resampled, y_resampled)

rf_ros.fit(X_resampled, y_resampled)
y_pred_ros = evaluationTest(rf_ros,X_test, y_test)
y_pred_ros.value_counts()

Cross-validation scores Precision    : [0.72413793 0.73333333 0.75862069 0.78571429 0.7037037 ]
Cross-validation scores Recall       : [0.84 0.88 0.88 0.88 0.76]
Cross-validation scores F1           : [0.77777778 0.8        0.81481481 0.83018868 0.73076923]
Cross-validation scores Accuracy     : [0.76 0.78 0.8  0.82 0.72]
Mean cross-validation score Precision: 0.7411019886881955
Mean cross-validation score Recall   : 0.8480000000000001
Mean cross-validation score F1       : 0.7907101005214212
Mean cross-validation score Accuracy : 0.776
Precision: 0.14351851851851852
Recall: 1.0
F1 Score: 0.25101214574898784
Accuracy: 0.6957236842105263


0    392
1    216
dtype: int64

In [49]:
cm = confusion_matrix(y_test, y_pred_ros)
print(cm)

[[392 185]
 [  0  31]]


In [50]:
X_onehot = pd.get_dummies(X)
X_train, X_test, y_train, y_test = train_test_split(X_onehot, y, test_size=0.20, stratify=y, random_state=42)

undersampler = RandomUnderSampler()
#undersampler = RandomUnderSampler(sampling_strategy={0: 250})
X_resampled, y_resampled = undersampler.fit_resample(X_train, y_train)

rf_ros = RandomForestClassifier(random_state=42,
                                max_depth = None, 
                                max_features = 'sqrt',
                                min_samples_leaf = 1,
                                min_samples_split = 1,
                                n_estimators= 100)

rf_ros.fit(X_resampled, y_resampled)
evaluationCV(rf_ros,X_resampled, y_resampled)

rf_ros.fit(X_resampled, y_resampled)
y_pred_ros = evaluationTest(rf_ros,X_test, y_test)
y_pred_ros.value_counts()

Cross-validation scores Precision    : [0.72413793 0.625      0.75862069 0.73333333 0.82608696]
Cross-validation scores Recall       : [0.84 0.8  0.88 0.88 0.76]
Cross-validation scores F1           : [0.77777778 0.70175439 0.81481481 0.8        0.79166667]
Cross-validation scores Accuracy     : [0.76 0.66 0.8  0.78 0.8 ]
Mean cross-validation score Precision: 0.7334357821089454
Mean cross-validation score Recall   : 0.8320000000000001
Mean cross-validation score F1       : 0.7772027290448345
Mean cross-validation score Accuracy : 0.76
Precision: 0.14871794871794872
Recall: 0.9354838709677419
F1 Score: 0.25663716814159293
Accuracy: 0.7236842105263158


0    413
1    195
dtype: int64

#### Trial 3: BRF

In [51]:
X_onehot = pd.get_dummies(X_top)
X_train, X_test, y_train, y_test = train_test_split(X_onehot, y, test_size=0.20, stratify=y, random_state=42)

ros = SMOTE(random_state=42)
X_resampled, y_resampled = ros.fit_resample(X_train, y_train)

brf = BalancedRandomForestClassifier(n_estimators=100)
brf.fit(X_resampled, y_resampled)

evaluationCV(brf,X_resampled, y_resampled)

y_pred_brf = evaluationTest(brf,X_test, y_test)
y_pred_brf.value_counts()





















































































































































































































































Cross-validation scores Precision    : [0.93275488 0.94409938 0.95560254 0.94020619 0.94421488]
Cross-validation scores Recall       : [0.93058568 0.99132321 0.97830803 0.98915401 0.99566161]
Cross-validation scores F1           : [0.93376764 0.96698616 0.96382979 0.97343252 0.96733404]
Cross-validation scores Accuracy     : [0.92950108 0.96420824 0.96637744 0.95986985 0.96420824]
Mean cross-validation score Precision: 0.943375571634817
Mean cross-validation score Recall   : 0.977006507592191
Mean cross-validation score F1       : 0.9610700282016775
Mean cross-validation score Accuracy : 0.9568329718004339
Precision: 0.225
Recall: 0.2903225806451613
F1 Score: 0.2535211267605634
Accuracy: 0.912828947368421


0    568
1     40
dtype: int64

In [52]:
X_onehot = pd.get_dummies(X_top)
X_train, X_test, y_train, y_test = train_test_split(X_onehot, y, test_size=0.20, stratify=y, random_state=42)

undersampler = RandomUnderSampler()
X_resampled, y_resampled = undersampler.fit_resample(X_train, y_train)

brf = BalancedRandomForestClassifier(n_estimators=100)
brf.fit(X_resampled, y_resampled)

evaluationCV(brf,X_resampled, y_resampled)

y_pred_brf = evaluationTest(brf,X_test, y_test)
y_pred_brf.value_counts()



























































Cross-validation scores Precision    : [0.85714286 0.71428571 0.68965517 0.67857143 0.65217391]
Cross-validation scores Recall       : [0.68 0.8  0.8  0.8  0.64]
Cross-validation scores F1           : [0.75       0.73076923 0.75471698 0.75471698 0.65306122]
Cross-validation scores Accuracy     : [0.78 0.7  0.76 0.76 0.66]
Mean cross-validation score Precision: 0.7183658170914544
Mean cross-validation score Recall   : 0.744
Mean cross-validation score F1       : 0.7286528835046356
Mean cross-validation score Accuracy : 0.732
Precision: 0.15135135135135136
Recall: 0.9032258064516129
F1 Score: 0.2592592592592593
Accuracy: 0.7368421052631579




0    423
1    185
dtype: int64

# 2. Experiments

demographics_windowed

merged_table

--------------------------------
static_demo_comorb

labs_windowed

vitals_windowed

In [53]:
age_ranges = [0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 400]
age_labels = ['0-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', '70-79', '80-89', '90+']
demographics_windowed['age_group'] = pd.cut(merged_table['admission_age'], bins=age_ranges, labels=age_labels, right=False)

In [54]:
static_demo_comorb = demographics_windowed[['icustay_id','gender', 'ethnicity_grouped', 'label_cor_art', 'diabetes_mellitus', 'anemia_flag', 'age_group', 'ckd']]

In [55]:
# Get the unique icustay_id values from each DataFrame
icustay_id_df1 = set(static_demo_comorb['icustay_id'])
icustay_id_df2 = set(labs_windowed['icustay_id'])
icustay_id_df3 = set(vitals_windowed['icustay_id'])

# Check for missing icustay_id values
missing_from_df1 = icustay_id_df2.union(icustay_id_df3) - icustay_id_df1
missing_from_df2 = icustay_id_df1.union(icustay_id_df3) - icustay_id_df2
missing_from_df3 = icustay_id_df1.union(icustay_id_df2) - icustay_id_df3

# Print the missing icustay_id values
if missing_from_df1:
    print(f"Icustay_id missing from static_demo_comorb: {missing_from_df1}")
else:
    print("No icustay_id missing from static_demo_comorb")

if missing_from_df2:
    print(f"Icustay_id missing from labs_windowed: {missing_from_df2}")
else:
    print("No icustay_id missing from labs_windowed")

if missing_from_df3:
    print(f"Icustay_id missing from vitals_windowed: {missing_from_df3}")
else:
    print("No icustay_id missing from vitals_windowed")



No icustay_id missing from static_demo_comorb
No icustay_id missing from labs_windowed
No icustay_id missing from vitals_windowed


In [56]:
# Calculate the difference between max and min charttime for labs_windowed
labs_diff = labs_windowed.groupby('icustay_id')['charttime'].apply(lambda x: x.max() - x.min()).to_frame()
labs_diff.min()

charttime   2 days 16:00:00
dtype: timedelta64[ns]

In [57]:
# Calculate the difference between max and min charttime for vitals_windowed
vitals_diff = vitals_windowed.groupby('icustay_id')['charttime'].apply(lambda x: x.max() - x.min()).to_frame()
vitals_diff.min()

charttime   2 days 23:00:00
dtype: timedelta64[ns]

In [58]:
WINDOW_LENGTH_NEW = 96

In [59]:
filtered_icustay_ids = labs_diff[(labs_diff == pd.Timedelta(hours=WINDOW_LENGTH_NEW)) & (vitals_diff == pd.Timedelta(hours=WINDOW_LENGTH_NEW))].index.tolist()
print("Total : ",len(filtered_icustay_ids))
print(static_demo_comorb['ckd'].value_counts())

Total :  8405
0    7868
1     537
Name: ckd, dtype: int64


## 2.1 Random Forest - Comorbidity & Demographics

In [62]:
def RandomForestForMulti(X_train, X_test, y_train, y_test, resampling=None):        
    if resampling == None:
        print("No resampling: Train:", y_train.shape[0] , "Test:", y_test.shape[0])
    elif resampling.lower() == 'under':
        sampler = RandomUnderSampler()
        X_train, y_train = sampler.fit_resample(X_train, y_train)
        print("under sampling: Train:", y_train.shape[0] , "Test:", y_test.shape[0])
    elif resampling.lower() == 'over':
        sampler = SMOTE()
        X_train, y_train = sampler.fit_resample(X_train, y_train)
        print("over sampling: Train:", y_train.shape[0] , "Test:", y_test.shape[0])
    
    #rf_static_demo_comorb = RandomForestClassifier(n_estimators=300, random_state=42)
        
    rf_static_demo_comorb_best = RandomForestClassifier(n_estimators=200, 
                                                        max_depth=None,
                                                        min_samples_leaf=2,
                                                        min_samples_split=2,
                                                        max_features='sqrt',
                                                        random_state=42)
    
    rf_static_demo_comorb_best.fit(X_train, y_train)
    
    evaluationCV(rf_static_demo_comorb_best,X_train, y_train)
    
    y_pred = evaluationTest(rf_static_demo_comorb_best,X_test, y_test)
    
    cm  = confusion_matrix(y_test, y_pred)    
    print(cm)
    
    f1 = f1_score(y_test, y_pred)
    weight = np.log(f1/(1-f1))
    
    proba = rf_static_demo_comorb_best.predict_proba(X_test)
    
    return rf_static_demo_comorb_best, weight, proba    

In [63]:
param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['sqrt', 'log2']
}

X=static_demo_comorb.drop(['ckd','icustay_id'],axis=1)
y=static_demo_comorb['ckd']

X_onehot=pd.get_dummies(X)
X_train, X_test, y_train, y_test = train_test_split(X_onehot, y, test_size=0.2,stratify = y, random_state=42)
rf_default, _, _ = RandomForestForMulti(X_train, X_test, y_train, y_test)

No resampling: Train: 6724 Test: 1681


  _warn_prf(average, modifier, msg_start, len(result))


Cross-validation scores Precision    : [0. 1. 1. 0. 1.]
Cross-validation scores Recall       : [0.         0.01162791 0.02325581 0.         0.02325581]
Cross-validation scores F1           : [0.         0.02298851 0.04545455 0.         0.04545455]
Cross-validation scores Accuracy     : [0.93457249 0.93680297 0.93754647 0.93605948 0.9375    ]
Mean cross-validation score Precision: 0.6
Mean cross-validation score Recall   : 0.011627906976744186
Mean cross-validation score F1       : 0.022779519331243465
Mean cross-validation score Accuracy : 0.9364962825278811
Precision: 0.0
Recall: 0.0
F1 Score: 0.0
Accuracy: 0.9363474122546104
[[1574    0]
 [ 107    0]]


  _warn_prf(average, modifier, msg_start, len(result))
  weight = np.log(f1/(1-f1))


In [65]:
X=static_demo_comorb.drop(['ckd','icustay_id'],axis=1)
y=static_demo_comorb['ckd']

X_onehot=pd.get_dummies(X)
X_train, X_test, y_train, y_test = train_test_split(X_onehot, y, test_size=0.2,stratify = y, random_state=42)
rf_under_sampled, _, _ = RandomForestForMulti(X_train, X_test, y_train, y_test, resampling='under')

under sampling: Train: 860 Test: 1681
Cross-validation scores Precision    : [0.65714286 0.56756757 0.64864865 0.671875   0.64634146]
Cross-validation scores Recall       : [0.53488372 0.48837209 0.55813953 0.5        0.61627907]
Cross-validation scores F1           : [0.58974359 0.525      0.6        0.57333333 0.63095238]
Cross-validation scores Accuracy     : [0.62790698 0.55813953 0.62790698 0.62790698 0.63953488]
Mean cross-validation score Precision: 0.6383151073547415
Mean cross-validation score Recall   : 0.5395348837209303
Mean cross-validation score F1       : 0.5838058608058608
Mean cross-validation score Accuracy : 0.6162790697674418
Precision: 0.14285714285714285
Recall: 0.5794392523364486
F1 Score: 0.22920517560073936
Accuracy: 0.7519333729922665
[[1202  372]
 [  45   62]]


In [66]:
X=static_demo_comorb.drop(['ckd','icustay_id'],axis=1)
y=static_demo_comorb['ckd']

X_onehot=pd.get_dummies(X)
X_train, X_test, y_train, y_test = train_test_split(X_onehot, y, test_size=0.2,stratify = y, random_state=42)
rf_over_sampled, _, _ = RandomForestForMulti(X_train, X_test, y_train, y_test, resampling='over')

over sampling: Train: 12588 Test: 1681
Cross-validation scores Precision    : [0.71916509 0.75406504 0.75048924 0.70673953 0.71103008]
Cross-validation scores Recall       : [0.60206513 0.58935663 0.60921366 0.61685215 0.61953932]
Cross-validation scores F1           : [0.65542585 0.66161391 0.67251206 0.65874363 0.66213922]
Cross-validation scores Accuracy     : [0.68347895 0.69857029 0.70333598 0.68057211 0.6837505 ]
Mean cross-validation score Precision: 0.7282977942567195
Mean cross-validation score Recall   : 0.6074053776245057
Mean cross-validation score F1       : 0.6620869344438469
Mean cross-validation score Accuracy : 0.6899415665294899
Precision: 0.12129380053908356
Recall: 0.4205607476635514
F1 Score: 0.1882845188284519
Accuracy: 0.7691850089232599
[[1248  326]
 [  62   45]]


## 2.2 Time series - Data preparation

In [67]:
def print_unique_shape(grouped_data,feature_cols):
    previous_shape = []
    for _, group in grouped_data:
        group_values = group[feature_cols].values.T
        if group_values.shape not in previous_shape:        
            print(group_values.shape)
            previous_shape.append(group_values.shape)

In [68]:
def check_shape_in_grouped_df(data_grouped, feature_cols, icustay_id):
    for _, group in data_grouped:
        if(group['icustay_id'].values[0]==icustay_id):
            group_values = group[feature_cols].values.T
            return group_values.shape

In [69]:
def check_missing_and_extras(data_windowed,feature_cols, threshold):
    data_windowed_new = data_windowed.copy()
    df_counts = data_windowed_new.groupby('icustay_id').count()
    previous_shape = []
    
    icustay_ids_less_records = df_counts[df_counts['charttime'] < threshold].index
    icustay_ids_more_records = df_counts[df_counts['charttime'] > threshold].index
    icustay_ids_correct_records = df_counts[df_counts['charttime'] == threshold].index
     
    print("len(icustay_ids_fewer_records)",len(icustay_ids_less_records))
    print("len(icustay_ids_more_records)",len(icustay_ids_more_records))  
    print("len(icustay_ids_correct_records)",len(icustay_ids_correct_records))
        
    for icustay_id in icustay_ids_more_records:
        df_grouped = data_windowed_new[data_windowed_new['icustay_id'] == icustay_id]
        if df_grouped.shape not in previous_shape:        
            print(f"There are records with more than {threshold} readings : {df_grouped.shape}")
            previous_shape.append(df_grouped.shape)
        
        # Check if the time span is more than 4 days
        if (df_grouped['charttime'].max() - df_grouped['charttime'].min()).days > 4:
            print(f"icustay_id: {icustay_id} has a time span of more than 4 days.")
        
        # Check for duplicate records
        if df_grouped.duplicated().sum() > 0:
            print(f"icustay_id: {icustay_id} has {df_grouped.duplicated().sum()} duplicate records.")
            
    for icustay_id in icustay_ids_less_records:
        print(f"icustay_id: {icustay_id} has a time span less than {threshold} records")


In [70]:
def backward_forward_fill(data_windowed, time_interval, threshold, feature_cols):    
    data_windowed_new = data_windowed.copy()
    data_windowed_new['charttime'] = pd.to_datetime(data_windowed_new['charttime']) 
    data_windowed_new.sort_values(['icustay_id', 'charttime'])
    df_filled = data_windowed_new.groupby('icustay_id').apply(lambda group: group.bfill().ffill())
    return df_filled    

In [71]:
def create_threshold_records(df, time_interval, threshold, feature_cols):
    df_new = pd.DataFrame()
    for id, group in df.groupby('icustay_id'):
        if len(group) > threshold:
            group = group.head(threshold)
        elif len(group) < threshold:
            missing_rows_count = threshold - len(group)
            last_timestamp = group['charttime'].max()
            missing_rows_df = pd.DataFrame({
                'icustay_id': [id]*missing_rows_count,
                'charttime': pd.date_range(start=last_timestamp + pd.Timedelta(hours=time_interval), 
                                           periods=missing_rows_count, 
                                           freq=f'{time_interval}H'),
                'ckd': [group['ckd'].iloc[0]]*missing_rows_count
            })
            for col in feature_cols:
                missing_rows_df[col] = np.nan
            group = pd.concat([group, missing_rows_df])
        df_new = pd.concat([df_new, group])
    df_new.sort_values(['icustay_id', 'charttime'], inplace=True)
    df_new.reset_index(drop=True, inplace=True)
    return df_new

### 2.2.1 Labs

In [72]:
feature_labs= ['aniongap', 'albumin', 'bands',
       'bicarbonate', 'bilirubin', 'creatinine', 'chloride', 'glucose',
       'hematocrit', 'hemoglobin', 'lactate', 'platelet', 'potassium', 'ptt',
       'inr', 'pt', 'sodium', 'bun', 'wbc', 'bacteria']

In [73]:
labs_windowed_12 = create_threshold_records(labs_windowed, 8, 12, feature_labs)
labs_windowed_12 = backward_forward_fill(labs_windowed_12, 8, 12, feature_labs)

labs_windowed_12['charttime'] = pd.to_datetime(labs_windowed_12['charttime']) 
labs_windowed_12.sort_values(['icustay_id', 'charttime'])
labs_grouped = labs_windowed_12[['icustay_id'] + feature_labs + ['ckd']].groupby(['icustay_id'])

To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  df_filled = data_windowed_new.groupby('icustay_id').apply(lambda group: group.bfill().ffill())


In [74]:
check_missing_and_extras(labs_windowed_12,feature_labs, 12)

len(icustay_ids_fewer_records) 0
len(icustay_ids_more_records) 0
len(icustay_ids_correct_records) 8405


In [75]:
print_unique_shape(labs_grouped,feature_labs)

(20, 12)


  for _, group in grouped_data:


### 2.2.2 Vitals

In [76]:
feature_vitals = ['heartrate', 'sysbp', 'diasbp', 'meanbp', 'resprate', 
                  'tempc', 'spo2', 'glucose', 'rbc', 'specificgravity', 'appetite_median']

vitals_windowed_96 = create_threshold_records(vitals_windowed, 1, 96, feature_vitals)
vitals_windowed_96 = backward_forward_fill(vitals_windowed_96, 1, 96, feature_vitals)

vitals_windowed_96['charttime'] = pd.to_datetime(vitals_windowed_96['charttime']) 
vitals_windowed_96.sort_values(['icustay_id', 'charttime'])
vitals_grouped = vitals_windowed_96[['icustay_id'] + feature_vitals + ['ckd']].groupby(['icustay_id'])

To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  df_filled = data_windowed_new.groupby('icustay_id').apply(lambda group: group.bfill().ffill())


In [77]:
check_missing_and_extras(vitals_windowed_96,feature_vitals, 96)

len(icustay_ids_fewer_records) 0
len(icustay_ids_more_records) 0
len(icustay_ids_correct_records) 8405


In [78]:
print_unique_shape(vitals_grouped,feature_vitals)

(11, 96)


  for _, group in grouped_data:


## 2.3 Rocket - Time series model

In [79]:
print("labs_grouped['icustay_id'].nunique(): ",len(labs_grouped['icustay_id'].nunique()))
print("vitals_grouped['icustay_id'].nunique(): ",len(vitals_grouped['icustay_id'].nunique()))

labs_grouped['icustay_id'].nunique():  8405
vitals_grouped['icustay_id'].nunique():  8405


In [133]:
def RocketMulti(grouped_data,feature_columns, num_kernels=100, resampling=None, filtered_test_ids = None):  
    import warnings
    import logging
    warnings.filterwarnings('ignore')   
    logging.getLogger().setLevel(logging.ERROR)
    
    X = []
    y = []
    for icustay_id, group in grouped_data:
        group_values = group[feature_columns].values.T
        num_timestamps = group_values.shape[1]
        if filtered_test_ids != None:
            if icustay_id in filtered_test_ids:                
                X_test.append(group_values)
                y_test.append(group['ckd'].iloc[0])
            else:
                X_train.append(group_values)
                y_train.append(group['ckd'].iloc[0])
        else:
            X.append(group_values)
            y.append(group['ckd'].iloc[0])        
          
    if filtered_test_ids != None:
        X_train = np.array(X_train)
        n_samples, n_features, n_channels = X_train.shape
        X_train_2d = X_train.reshape((n_samples, n_features*n_channels))        
        y_train = np.array(y_train)
        
        X_test = np.array(X_test)
        n_samples, n_features, n_channels = X_test.shape
        X_test_2d = X_test.reshape((n_samples, n_features*n_channels)) 
        y_test = np.array(y_test)
    else:    
        X = np.array(X)
        y = np.array(y)
        n_samples, n_features, n_channels = X.shape
        X_2d = X.reshape((n_samples, n_features*n_channels))          
        X_train_2d, X_test_2d, y_train, y_test = train_test_split(X_2d, y, test_size=0.2, stratify=y)
    
    if resampling == None:
        print("No resampling: Train:", y_train.shape[0] , "Test:", y_test.shape[0])
    elif (resampling.lower() == 'under'):
        sampler = RandomUnderSampler(random_state=42)
        X_train_2d, y_train = sampler.fit_resample(X_train_2d, y_train)
        print("Under sampling: Train:", y_train.shape[0] , "Test:", y_test.shape[0])
    elif (resampling.lower() == 'over'):
        sampler = SMOTE(random_state=42)
        X_train_2d, y_train = sampler.fit_resample(X_train_2d, y_train)
        print("Over sampling: Train:", y_train.shape[0] , "Test:", y_test.shape[0])
    
    X_train = X_train_2d.reshape((X_train_2d.shape[0], n_features, n_channels))
    X_test = X_test_2d.reshape((X_test_2d.shape[0], n_features, n_channels))

    print("X_train shape: ",X_train.shape,"\ny_train shape: ",y_train.shape)
    
    rocket = Rocket(num_kernels, random_state=42)
    rocket.fit(X_train)
    
    X_train_transformed = rocket.transform(X_train)
    X_test_transformed = rocket.transform(X_test)

    # Reshape transformed data back into 2D for logistic regression
    X_train_transformed_2d = X_train_transformed.values.reshape((X_train_transformed.shape[0], -1))
    X_test_transformed_2d = X_test_transformed.values.reshape((X_test_transformed.shape[0], -1))
    
    clf = LogisticRegression(random_state=42, max_iter=1000)
    clf.fit(X_train_transformed_2d, y_train)
    
    evaluationCV(clf,X_train_transformed_2d, y_train)
    y_pred = evaluationTest(clf,X_test_transformed_2d, y_test)
    
    cm  = confusion_matrix(y_test, y_pred)
    print(cm)
    
    f1 = f1_score(y_test, y_pred)
    weight = np.log(f1/(1-f1))
    
    proba = clf.predict_proba(X_test_transformed_2d)
    
    return clf, weight, proba

### 2.3.1 Rocket for Labs

In [134]:
clf_lab_default, _, _ = RocketMulti(labs_grouped,feature_labs)

AttributeError: 'NoneType' object has no attribute 'empty'

#### 2.3.1.1 Undersampling -labs

In [82]:
clf_lab_under_sampled_default, _, _ = RocketMulti(labs_grouped,feature_labs, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 20, 12) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.65591398 0.57692308 0.62790698 0.60493827 0.67010309]
Cross-validation scores Recall       : [0.70930233 0.69767442 0.62790698 0.56976744 0.75581395]
Cross-validation scores F1           : [0.68156425 0.63157895 0.62790698 0.58682635 0.71038251]
Cross-validation scores Accuracy     : [0.66860465 0.59302326 0.62790698 0.59883721 0.69186047]
Mean cross-validation score Precision: 0.627157079310066
Mean cross-validation score Recall   : 0.672093023255814
Mean cross-validation score F1       : 0.6476518061778509
Mean cross-validation score Accuracy : 0.6360465116279069
Precision: 0.10384068278805121
Recall: 0.6822429906542056
F1 Score: 0.18024691358024691
Accuracy: 0.6049970255800119
[[944 630]
 [ 34  73]]


In [83]:
clf_lab_under_sampled_10, _, _ = RocketMulti(labs_grouped,feature_labs, num_kernels = 10, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 20, 12) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.55769231 0.56435644 0.54285714 0.60952381 0.60377358]
Cross-validation scores Recall       : [0.6744186  0.6627907  0.6627907  0.74418605 0.74418605]
Cross-validation scores F1           : [0.61052632 0.60962567 0.59685864 0.67015707 0.66666667]
Cross-validation scores Accuracy     : [0.56976744 0.5755814  0.55232558 0.63372093 0.62790698]
Mean cross-validation score Precision: 0.5756406561244969
Mean cross-validation score Recall   : 0.6976744186046512
Mean cross-validation score F1       : 0.6307668715423241
Mean cross-validation score Accuracy : 0.5918604651162791
Precision: 0.0906832298136646
Recall: 0.6822429906542056
F1 Score: 0.1600877192982456
Accuracy: 0.5443188578227246
[[842 732]
 [ 34  73]]


In [84]:
clf_lab_under_sampled_25, _, _ = RocketMulti(labs_grouped,feature_labs, num_kernels = 25, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 20, 12) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.6626506  0.69318182 0.57575758 0.56989247 0.64485981]
Cross-validation scores Recall       : [0.63953488 0.70930233 0.6627907  0.61627907 0.80232558]
Cross-validation scores F1           : [0.65088757 0.70114943 0.61621622 0.59217877 0.71502591]
Cross-validation scores Accuracy     : [0.65697674 0.69767442 0.5872093  0.5755814  0.68023256]
Mean cross-validation score Precision: 0.6292684565102847
Mean cross-validation score Recall   : 0.6860465116279071
Mean cross-validation score F1       : 0.6550915786307082
Mean cross-validation score Accuracy : 0.6395348837209303
Precision: 0.10463576158940398
Recall: 0.7383177570093458
F1 Score: 0.18329466357308585
Accuracy: 0.5812016656751934
[[898 676]
 [ 28  79]]


In [85]:
clf_lab_under_sampled_50, _, _ = RocketMulti(labs_grouped,feature_labs, num_kernels = 50, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 20, 12) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.63095238 0.61764706 0.65882353 0.61797753 0.67021277]
Cross-validation scores Recall       : [0.61627907 0.73255814 0.65116279 0.63953488 0.73255814]
Cross-validation scores F1           : [0.62352941 0.67021277 0.65497076 0.62857143 0.7       ]
Cross-validation scores Accuracy     : [0.62790698 0.63953488 0.65697674 0.62209302 0.68604651]
Mean cross-validation score Precision: 0.639122652647002
Mean cross-validation score Recall   : 0.6744186046511629
Mean cross-validation score F1       : 0.6554568733055
Mean cross-validation score Accuracy : 0.6465116279069767
Precision: 0.09880239520958084
Recall: 0.616822429906542
F1 Score: 0.1703225806451613
Accuracy: 0.6174895895300416
[[972 602]
 [ 41  66]]


In [86]:
clf_lab_under_sampled_75, _, _ = RocketMulti(labs_grouped,feature_labs, num_kernels = 75, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 20, 12) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.63       0.65591398 0.64444444 0.58064516 0.60215054]
Cross-validation scores Recall       : [0.73255814 0.70930233 0.6744186  0.62790698 0.65116279]
Cross-validation scores F1           : [0.67741935 0.68156425 0.65909091 0.60335196 0.62569832]
Cross-validation scores Accuracy     : [0.65116279 0.66860465 0.65116279 0.5872093  0.61046512]
Mean cross-validation score Precision: 0.6226308243727599
Mean cross-validation score Recall   : 0.6790697674418604
Mean cross-validation score F1       : 0.6494249578138568
Mean cross-validation score Accuracy : 0.6337209302325582
Precision: 0.1016949152542373
Recall: 0.6728971962616822
F1 Score: 0.17668711656441718
Accuracy: 0.6008328375966686
[[938 636]
 [ 35  72]]


#### 2.3.1.2 Oversampling - labs

In [87]:
clf_lab_over_sampled_default, _, _ = RocketMulti(labs_grouped,feature_labs, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 20, 12) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.73766641 0.74481773 0.75439883 0.73296628 0.7375    ]
Cross-validation scores Recall       : [0.74821287 0.82764098 0.81731533 0.84658188 0.84352661]
Cross-validation scores F1           : [0.74290221 0.78404816 0.78459779 0.78568794 0.78695813]
Cross-validation scores Accuracy     : [0.74106434 0.7720413  0.77561557 0.76916965 0.77155344]
Mean cross-validation score Precision: 0.7414698472468935
Mean cross-validation score Recall   : 0.8166555332606821
Mean cross-validation score F1       : 0.7768388448347945
Mean cross-validation score Accuracy : 0.7658888580685493
Precision: 0.11506276150627615
Recall: 0.514018691588785
F1 Score: 0.18803418803418803
Accuracy: 0.7174301011302796
[[1151  423]
 [  52   55]]


In [88]:
clf_lab_over_sampled_5, _, _ = RocketMulti(labs_grouped,feature_labs, num_kernels = 5, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 20, 12) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.61526946 0.6173913  0.61751497 0.62727936 0.62178072]
Cross-validation scores Recall       : [0.65289913 0.67672756 0.65528197 0.6836248  0.67116759]
Cross-validation scores F1           : [0.63352601 0.64569913 0.63583815 0.65424116 0.64553094]
Cross-validation scores Accuracy     : [0.6223193  0.62867355 0.62470214 0.63885578 0.63130711]
Mean cross-validation score Precision: 0.6198471629473541
Mean cross-validation score Recall   : 0.6679402104529423
Mean cross-validation score F1       : 0.6429670772578507
Mean cross-validation score Accuracy : 0.6291715776721472
Precision: 0.09344490934449093
Recall: 0.6261682242990654
F1 Score: 0.16262135922330098
Accuracy: 0.5895300416418798
[[924 650]
 [ 40  67]]


In [89]:
clf_lab_over_sampled_10, _, _ = RocketMulti(labs_grouped,feature_labs, num_kernels = 10, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 20, 12) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.62947067 0.63574814 0.64587814 0.6565097  0.64388243]
Cross-validation scores Recall       : [0.69896743 0.74583002 0.71564734 0.75357711 0.74821287]
Cross-validation scores F1           : [0.6624012  0.68640351 0.67897513 0.70170244 0.69213813]
Cross-validation scores Accuracy     : [0.64376489 0.65925338 0.66163622 0.67977751 0.66706397]
Mean cross-validation score Precision: 0.6422978150709253
Mean cross-validation score Recall   : 0.7324469542663253
Mean cross-validation score F1       : 0.6843240842748912
Mean cross-validation score Accuracy : 0.6622991931277163
Precision: 0.08669656203288491
Recall: 0.5420560747663551
F1 Score: 0.14948453608247425
Accuracy: 0.6073765615704938
[[963 611]
 [ 49  58]]


In [90]:
clf_lab_over_sampled_25, _, _ = RocketMulti(labs_grouped,feature_labs, num_kernels = 25, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 20, 12) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.71283255 0.71161826 0.69678303 0.71886121 0.72561863]
Cross-validation scores Recall       : [0.72359015 0.81731533 0.80857824 0.80286169 0.79189833]
Cross-validation scores F1           : [0.71817107 0.76081331 0.74852941 0.758543   0.75731105]
Cross-validation scores Accuracy     : [0.71604448 0.74305004 0.72835584 0.74453715 0.74612634]
Mean cross-validation score Precision: 0.7131427350287665
Mean cross-validation score Recall   : 0.7888487468920118
Mean cross-validation score F1       : 0.7486735674590708
Mean cross-validation score Accuracy : 0.7356227691412454
Precision: 0.10185185185185185
Recall: 0.514018691588785
F1 Score: 0.17001545595054096
Accuracy: 0.6805472932778108
[[1089  485]
 [  52   55]]


In [91]:
clf_lab_over_sampled_50, _, _ = RocketMulti(labs_grouped,feature_labs, num_kernels = 50, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 20, 12) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.72782875 0.73821253 0.73973556 0.72047519 0.73710602]
Cross-validation scores Recall       : [0.75615568 0.83320095 0.84432089 0.81955485 0.81731533]
Cross-validation scores F1           : [0.74172185 0.78283582 0.78857567 0.76682782 0.77514124]
Cross-validation scores Accuracy     : [0.73669579 0.76886418 0.77362986 0.75089392 0.76281287]
Mean cross-validation score Precision: 0.7326716084254739
Mean cross-validation score Recall   : 0.8141095400872066
Mean cross-validation score F1       : 0.7710204805652303
Mean cross-validation score Accuracy : 0.7585793254006197
Precision: 0.10877862595419847
Recall: 0.5327102803738317
F1 Score: 0.18066561014263072
Accuracy: 0.6924449732302201
[[1107  467]
 [  50   57]]


In [92]:
clf_lab_over_sampled_75, _, _ = RocketMulti(labs_grouped,feature_labs, num_kernels = 75, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 20, 12) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.75378486 0.75759717 0.75197133 0.76483516 0.75807609]
Cross-validation scores Recall       : [0.75138999 0.85146942 0.83320095 0.82988871 0.83876092]
Cross-validation scores F1           : [0.75258552 0.80179506 0.7905049  0.79603507 0.79638009]
Cross-validation scores Accuracy     : [0.75297855 0.78951549 0.77918983 0.78744537 0.78545888]
Mean cross-validation score Precision: 0.7572529238924393
Mean cross-validation score Recall   : 0.8209419997954315
Mean cross-validation score F1       : 0.787460129552794
Mean cross-validation score Accuracy : 0.7789176254369414
Precision: 0.10395010395010396
Recall: 0.4672897196261682
F1 Score: 0.17006802721088438
Accuracy: 0.7096966091612136
[[1143  431]
 [  57   50]]


### 2.3.2 Rocket for Vitals

In [109]:
clf_vitals_default, _, _ = RocketMulti(vitals_grouped,feature_vitals)

No resampling: Train: 6724 Test: 1681
X_train shape:  (6724, 11, 96) 
y_train shape:  (6724,)
Cross-validation scores Precision    : [0. 0. 0. 0. 0.]
Cross-validation scores Recall       : [0. 0. 0. 0. 0.]
Cross-validation scores F1           : [0. 0. 0. 0. 0.]
Cross-validation scores Accuracy     : [0.93531599 0.93531599 0.93605948 0.93457249 0.9360119 ]
Mean cross-validation score Precision: 0.0
Mean cross-validation score Recall   : 0.0
Mean cross-validation score F1       : 0.0
Mean cross-validation score Accuracy : 0.9354551690564701
Precision: 1.0
Recall: 0.009345794392523364
F1 Score: 0.018518518518518517
Accuracy: 0.9369422962522308
[[1574    0]
 [ 106    1]]


#### 2.3.2.1 Undersampling - Vitals

In [110]:
clf_vitals_under_sampled_default, _, _ = RocketMulti(vitals_grouped,feature_vitals, resampling="Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 11, 96) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.62068966 0.56701031 0.63265306 0.57731959 0.58823529]
Cross-validation scores Recall       : [0.62790698 0.63953488 0.72093023 0.65116279 0.58139535]
Cross-validation scores F1           : [0.62427746 0.6010929  0.67391304 0.61202186 0.58479532]
Cross-validation scores Accuracy     : [0.62209302 0.5755814  0.65116279 0.5872093  0.5872093 ]
Mean cross-validation score Precision: 0.5971815814843535
Mean cross-validation score Recall   : 0.644186046511628
Mean cross-validation score F1       : 0.6192201151722895
Mean cross-validation score Accuracy : 0.6046511627906976
Precision: 0.09115281501340483
Recall: 0.6355140186915887
F1 Score: 0.15943728018757328
Accuracy: 0.5734681737061273
[[896 678]
 [ 39  68]]


In [111]:
clf_vitals_under_sampled_5, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 5, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 11, 96) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.60655738 0.65384615 0.57894737 0.54285714 0.55714286]
Cross-validation scores Recall       : [0.86046512 0.79069767 0.89534884 0.88372093 0.90697674]
Cross-validation scores F1           : [0.71153846 0.71578947 0.70319635 0.67256637 0.69026549]
Cross-validation scores Accuracy     : [0.65116279 0.68604651 0.62209302 0.56976744 0.59302326]
Mean cross-validation score Precision: 0.5878701798632774
Mean cross-validation score Recall   : 0.8674418604651162
Mean cross-validation score F1       : 0.698671228132343
Mean cross-validation score Accuracy : 0.6244186046511628
Precision: 0.08999081726354453
Recall: 0.9158878504672897
F1 Score: 0.16387959866220736
Accuracy: 0.405116002379536
[[583 991]
 [  9  98]]


In [112]:
clf_vitals_under_sampled_10, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 10, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 11, 96) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.57522124 0.56034483 0.5982906  0.58       0.61682243]
Cross-validation scores Recall       : [0.75581395 0.75581395 0.81395349 0.6744186  0.76744186]
Cross-validation scores F1           : [0.65326633 0.64356436 0.68965517 0.62365591 0.68393782]
Cross-validation scores Accuracy     : [0.59883721 0.58139535 0.63372093 0.59302326 0.64534884]
Mean cross-validation score Precision: 0.58613581894428
Mean cross-validation score Recall   : 0.7534883720930232
Mean cross-validation score F1       : 0.6588159196640839
Mean cross-validation score Accuracy : 0.6104651162790697
Precision: 0.09144893111638955
Recall: 0.719626168224299
F1 Score: 0.16227608008429928
Accuracy: 0.5270672218917312
[[809 765]
 [ 30  77]]


In [113]:
clf_vitals_under_sampled_25, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 25, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 11, 96) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.60169492 0.60683761 0.625      0.62857143 0.59292035]
Cross-validation scores Recall       : [0.8255814  0.8255814  0.81395349 0.76744186 0.77906977]
Cross-validation scores F1           : [0.69607843 0.69950739 0.70707071 0.69109948 0.67336683]
Cross-validation scores Accuracy     : [0.63953488 0.64534884 0.6627907  0.65697674 0.62209302]
Mean cross-validation score Precision: 0.6110048609291147
Mean cross-validation score Recall   : 0.8023255813953488
Mean cross-validation score F1       : 0.6934245676432925
Mean cross-validation score Accuracy : 0.6453488372093024
Precision: 0.09129967776584318
Recall: 0.794392523364486
F1 Score: 0.16377649325626206
Accuracy: 0.48364069006543725
[[728 846]
 [ 22  85]]


In [114]:
clf_vitals_under_sampled_50, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 50, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 11, 96) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.64516129 0.54464286 0.58163265 0.58415842 0.61363636]
Cross-validation scores Recall       : [0.69767442 0.70930233 0.6627907  0.68604651 0.62790698]
Cross-validation scores F1           : [0.67039106 0.61616162 0.61956522 0.63101604 0.62068966]
Cross-validation scores Accuracy     : [0.65697674 0.55813953 0.59302326 0.59883721 0.61627907]
Mean cross-validation score Precision: 0.5938463160009221
Mean cross-validation score Recall   : 0.6767441860465115
Mean cross-validation score F1       : 0.6315647185917193
Mean cross-validation score Accuracy : 0.6046511627906976
Precision: 0.09768637532133675
Recall: 0.7102803738317757
F1 Score: 0.17175141242937852
Accuracy: 0.5639500297441998
[[872 702]
 [ 31  76]]


In [115]:
clf_vitals_under_sampled_75, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 75, resampling = "Under")

Under sampling: Train: 860 Test: 1681
X_train shape:  (860, 11, 96) 
y_train shape:  (860,)
Cross-validation scores Precision    : [0.62068966 0.625      0.62376238 0.61764706 0.60747664]
Cross-validation scores Recall       : [0.62790698 0.63953488 0.73255814 0.73255814 0.75581395]
Cross-validation scores F1           : [0.62427746 0.63218391 0.67379679 0.67021277 0.67357513]
Cross-validation scores Accuracy     : [0.62209302 0.62790698 0.64534884 0.63953488 0.63372093]
Mean cross-validation score Precision: 0.6189151451495171
Mean cross-validation score Recall   : 0.6976744186046512
Mean cross-validation score F1       : 0.6548092103256703
Mean cross-validation score Accuracy : 0.6337209302325582
Precision: 0.09523809523809523
Recall: 0.6915887850467289
F1 Score: 0.167420814479638
Accuracy: 0.5621653777513385
[[871 703]
 [ 33  74]]


#### 2.3.2.2 oversamplimg for vitals

In [116]:
clf_vitals_over_sampled_default, _, _ = RocketMulti(vitals_grouped,feature_vitals, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 11, 96) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.71325648 0.72411397 0.73310345 0.742715   0.73347398]
Cross-validation scores Recall       : [0.78633836 0.82764098 0.84432089 0.83068362 0.82843527]
Cross-validation scores F1           : [0.74801662 0.77242402 0.78479144 0.78424015 0.77806789]
Cross-validation scores Accuracy     : [0.73510723 0.75615568 0.76846704 0.77155344 0.76360747]
Mean cross-validation score Precision: 0.7293325754429675
Mean cross-validation score Recall   : 0.8234838258339636
Mean cross-validation score F1       : 0.7735080223096197
Mean cross-validation score Accuracy : 0.7589781700481206
Precision: 0.10984848484848485
Recall: 0.5420560747663551
F1 Score: 0.1826771653543307
Accuracy: 0.6912552052349792
[[1104  470]
 [  49   58]]


In [117]:
clf_vitals_over_sampled_5, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 5, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 11, 96) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.58164852 0.58395324 0.59495003 0.58539149 0.59429477]
Cross-validation scores Recall       : [0.89118348 0.87291501 0.89833201 0.88553259 0.89356632]
Cross-validation scores F1           : [0.70388959 0.69977714 0.71582278 0.70484024 0.71383249]
Cross-validation scores Accuracy     : [0.62509929 0.62549643 0.64336775 0.62932062 0.64163687]
Mean cross-validation score Precision: 0.588047609483367
Mean cross-validation score Recall   : 0.8883058828580485
Mean cross-validation score F1       : 0.7076324479070407
Mean cross-validation score Accuracy : 0.6329841904280441
Precision: 0.08736059479553904
Recall: 0.8785046728971962
F1 Score: 0.15891800507185125
Accuracy: 0.4080904223676383
[[592 982]
 [ 13  94]]


In [118]:
clf_vitals_over_sampled_10, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 10, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 11, 96) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.62586315 0.62044653 0.61921708 0.62317961 0.61982507]
Cross-validation scores Recall       : [0.79189833 0.83876092 0.82922955 0.81572677 0.84499205]
Cross-validation scores F1           : [0.69915849 0.71327254 0.7089983  0.70657035 0.71510259]
Cross-validation scores Accuracy     : [0.65925338 0.66282764 0.65965052 0.66110449 0.66348828]
Mean cross-validation score Precision: 0.6217062902328416
Mean cross-validation score Recall   : 0.8241215237570888
Mean cross-validation score F1       : 0.708620453590815
Mean cross-validation score Accuracy : 0.6612648604264629
Precision: 0.09038901601830664
Recall: 0.7383177570093458
F1 Score: 0.16106014271151886
Accuracy: 0.5104104699583581
[[779 795]
 [ 28  79]]


In [119]:
clf_vitals_over_sampled_25, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 25, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 11, 96) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.65037853 0.68362688 0.67340287 0.67620286 0.66358229]
Cross-validation scores Recall       : [0.75059571 0.83240667 0.82049245 0.82670906 0.7974583 ]
Cross-validation scores F1           : [0.69690265 0.75071633 0.73970641 0.74391989 0.72438672]
Cross-validation scores Accuracy     : [0.67355044 0.72359015 0.71127879 0.71553437 0.69646404]
Mean cross-validation score Precision: 0.669438683723307
Mean cross-validation score Recall   : 0.8055324398827647
Mean cross-validation score F1       : 0.7311264012124664
Mean cross-validation score Accuracy : 0.7040835582534397
Precision: 0.10195530726256984
Recall: 0.6822429906542056
F1 Score: 0.17739975698663427
Accuracy: 0.5972635336109459
[[931 643]
 [ 34  73]]


In [120]:
clf_vitals_over_sampled_50, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 50, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 11, 96) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.66505858 0.69364548 0.6755102  0.7006993  0.67453158]
Cross-validation scores Recall       : [0.76648133 0.82366958 0.78872121 0.79650238 0.7720413 ]
Cross-validation scores F1           : [0.71217712 0.75308642 0.7277391  0.74553571 0.72      ]
Cross-validation scores Accuracy     : [0.69023034 0.7299444  0.70492454 0.72824791 0.69964243]
Mean cross-validation score Precision: 0.6818890290630313
Mean cross-validation score Recall   : 0.7894831616179091
Mean cross-validation score F1       : 0.7317076708761847
Mean cross-validation score Accuracy : 0.710597926159305
Precision: 0.08641975308641975
Recall: 0.5233644859813084
F1 Score: 0.14834437086092714
Accuracy: 0.6174895895300416
[[982 592]
 [ 51  56]]


In [121]:
clf_vitals_over_sampled_75, _, _ = RocketMulti(vitals_grouped,feature_vitals, num_kernels = 75, resampling = "Over")

Over sampling: Train: 12588 Test: 1681
X_train shape:  (12588, 11, 96) 
y_train shape:  (12588,)
Cross-validation scores Precision    : [0.69530686 0.69163293 0.69753086 0.69121622 0.70277976]
Cross-validation scores Recall       : [0.76489277 0.8141382  0.80778396 0.81254964 0.78378378]
Cross-validation scores F1           : [0.72844175 0.74790223 0.7486198  0.74698795 0.74107478]
Cross-validation scores Accuracy     : [0.71485306 0.72557585 0.72875298 0.72467223 0.72626142]
Mean cross-validation score Precision: 0.695693325151342
Mean cross-validation score Recall   : 0.7966296717686709
Mean cross-validation score F1       : 0.7426053034712763
Mean cross-validation score Accuracy : 0.7240231083122456
Precision: 0.0986159169550173
Recall: 0.5327102803738317
F1 Score: 0.16642335766423358
Accuracy: 0.6603212373587151
[[1053  521]
 [  50   57]]


## 2.4 Multimodal

In [122]:
X=static_demo_comorb.drop(['ckd'],axis=1)
y=static_demo_comorb['ckd']

X_onehot=pd.get_dummies(X)
X_train, X_test, y_train, y_test = train_test_split(X_onehot, y, test_size=0.2,stratify = y, random_state=42)

mm_test_ids = X_test['icustay_id']

X_train =X_train.drop(['icustay_id'],axis=1)
X_test  =X_test.drop(['icustay_id'],axis=1)

In [123]:
#### Best model for static - Under sampled Random forest model
rf_mm, weight_static, prob_static = RandomForestForMulti(X_train, X_test, y_train, y_test, resampling='under')

under sampling: Train: 860 Test: 1681
Cross-validation scores Precision    : [0.62820513 0.5875     0.72222222 0.7        0.61842105]
Cross-validation scores Recall       : [0.56976744 0.54651163 0.60465116 0.56976744 0.54651163]
Cross-validation scores F1           : [0.59756098 0.56626506 0.65822785 0.62820513 0.58024691]
Cross-validation scores Accuracy     : [0.61627907 0.58139535 0.68604651 0.6627907  0.60465116]
Mean cross-validation score Precision: 0.6512696806117859
Mean cross-validation score Recall   : 0.5674418604651164
Mean cross-validation score F1       : 0.6061011851474721
Mean cross-validation score Accuracy : 0.6302325581395347
Precision: 0.12667946257197696
Recall: 0.616822429906542
F1 Score: 0.21019108280254778
Accuracy: 0.7049375371802499
[[1119  455]
 [  41   66]]


In [128]:
mm_test_ids

7418    209543
5273    280892
3518    292064
2120    202033
6878    238436
         ...  
2904    252989
4861    226571
7604    231247
6321    285069
5440    257949
Name: icustay_id, Length: 1681, dtype: int64

In [132]:
#### Best model for labs - Under sampled Rocket model with 50 kernels
clf_lab_mm, weight_lab, prob_lab = RocketMulti(labs_grouped,feature_labs, 50, resampling = "Under", filtered_test_ids = mm_test_ids)

UnboundLocalError: local variable 'X_train' referenced before assignment

In [None]:
#### Best model for vitals - Under sampled Rocket model
clf_vital_mm, weight_vital, prob_vital = RocketMulti(vitals_grouped,feature_vitals, resampling = "Over", filtered_test_ids = mm_test_ids)

In [None]:
def soft_voting(clf_static, clf_lab, clf_vital, weight_static, weight_lab, weight_vital,
               prob_static, prob_lab, prob_vital, y_test):
    
    weighted_prob = (weight_static * prob_static +  weight_lab * prob_lab + weight_vital * prob_vital) / np.sum(weight_static,weight_lab, weight_vital)

    y_pred = np.argmax(weighted_prob, axis=1)
    
    metricsReport(y_test,y_pred)

In [None]:
print("weight_static",weight_static)
print("weight_lab",weight_lab)
print("weight_vital",weight_vital)
#print("prob_static",prob_static)
#print("prob_lab",prob_lab)
#print("prob_vital",prob_vital) 

In [None]:
soft_voting(rf_mm, clf_lab_mm, clf_vital_mm, weight_static, weight_lab, weight_vital,
               prob_static, prob_lab, prob_vital, y_test)