# Libraries

In [None]:
# activate line execution
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# general
import numpy as np
import pandas as pd


# import custom libraries
import sys
import os
import tqdm
import pickle
import yaml

In [None]:
os.getcwd()

In [None]:

# plotly
import plotly.express as px  # (version 4.7.0 or higher)
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold

from sklearn.preprocessing import OneHotEncoder

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Header

In [None]:
# Names of the datasets this notebook has a loading branch for.
list_datasets = ['p12','p19','mimic','sim-l0.5-d16']

# Dataset processed in this run; a matching ../configs/data/<DATASET>.yaml must exist.
DATASET = 'sim-l0.5-d16'

# Load the per-dataset configuration. The context manager closes the file handle
# as soon as the YAML has been parsed (the bare open() call leaked it).
with open(f'../configs/data/{DATASET}.yaml') as config_file:
    f = yaml.safe_load(config_file)

PATH_RAW = f['path_raw']
PATH_PROCESSED = f['path_processed']
# create path_processed if it does not exist
os.makedirs(PATH_PROCESSED, exist_ok=True)


print(f"DATASET: {DATASET}")
print(f"PATH_RAW: {PATH_RAW}")
print(f"PATH_PROCESSED: {PATH_PROCESSED}")

# This will shuffle the order of clinical time series in the image.
# NOTE(review): the name is a typo for SHUFFLE_VARS; it is kept because a later cell reads it.
SUFFLE_VARS = True


# for TimEHR format
GRAN=1 # granularity of the image (in hours)
IMG_SIZE = f['img_size'] # side length of the image, read from the config key 'img_size'
N_SPLIT = 5 # number of splits for the cross-validation


In [None]:
# Scratch check: base-2 logarithm of 64 as an integer.
# NOTE(review): presumably relates to an image size of 64 — confirm, or delete this cell.
int(np.log2(64))

# Functions

In [None]:
def create_df_demo(df_demo, demo_vars):
    """Impute, standardise and encode the static (demographic) table.

    The input frame is left untouched; all work happens on a new frame.

    Steps, in order:
      1. every -1 is treated as a missing value and replaced by NaN;
      2. NaNs in every column are filled with that column's mean;
      3. 'Age', 'Height', 'Weight' and 'HospAdmTime' (when listed in
         ``demo_vars``) are z-scored with this frame's own mean/std;
      4. 'ICUType' (when listed) is one-hot encoded into 'ICUType0',
         'ICUType1', ... and the original column is dropped;
      5. 'Gender' (when listed) is cast to int;
      6. a 'dict_demo' column holding the per-row list of encoded values
         is appended.

    Args:
        df_demo: DataFrame with a 'RecordID' column, optionally a 'Hospital'
            column, and the numeric conditioning columns.
        demo_vars: names of the conditioning variables; only used to decide
            which of the steps above apply.

    Returns:
        (df_demo, dict_map_demos, demo_statistics) where ``dict_map_demos``
        maps each encoded column name (everything except 'RecordID' and
        'Hospital') to its position in 'dict_demo', and ``demo_statistics``
        maps each standardised column to ``{'mean': ..., 'std': ...}``.
    """
    # mask() returns a new frame, so the caller's data is never modified.
    df_demo = df_demo.mask(df_demo == -1)

    # Mean imputation is applied to every column of the frame.
    for column in df_demo.columns:
        df_demo[column] = df_demo[column].fillna(df_demo[column].mean())

    # standardize continuous variables
    # NOTE(review): a constant column has std == 0 and becomes NaN here.
    demo_statistics = {}
    for column in demo_vars:
        if column in ['Age', 'Height', 'Weight', 'HospAdmTime']:
            col_mean = df_demo[column].mean()
            col_std = df_demo[column].std()
            demo_statistics[column] = {'mean': col_mean, 'std': col_std}
            df_demo[column] = (df_demo[column] - col_mean) / col_std

    # discretize ICUType
    if 'ICUType' in demo_vars:
        # Imported lazily: sklearn is only needed when ICUType is present.
        from sklearn.preprocessing import OneHotEncoder

        ohe = OneHotEncoder()
        transformed = ohe.fit_transform(df_demo[['ICUType']])
        mat_ICU_enc = pd.DataFrame.sparse.from_spmatrix(transformed).values.astype(int)
        new_cols = ['ICUType' + str(i) for i in range(mat_ICU_enc.shape[1])]
        df_demo[new_cols] = mat_ICU_enc

        df_demo = df_demo.drop(columns='ICUType')

    if 'Gender' in demo_vars:
        df_demo['Gender'] = df_demo['Gender'].astype(int)

    # Encoded conditioning columns: everything except the identifiers.
    demo_vars_enc = list(df_demo.columns)
    demo_vars_enc.remove('RecordID')
    if 'Hospital' in demo_vars_enc:
        demo_vars_enc.remove('Hospital')

    print(demo_vars_enc)
    print(df_demo)
    df_demo['dict_demo'] = df_demo[demo_vars_enc].apply(lambda x: list(x), axis=1)

    dict_map_demos = {k: i for i, k in enumerate(demo_vars_enc)}

    return df_demo, dict_map_demos, demo_statistics


In [None]:
def custom_process(df_demo, df_filt, train_ids,state_vars, conditional_vars):
    """Split static and time-series tables into train/test and normalise them.

    Rows whose 'RecordID' is in ``train_ids`` form the training part, all
    other rows the test part. When no test row exists, the test part is a
    copy of the training part.

    The columns in ``state_vars`` are z-scored with the training mean/std,
    values more than 3 standard deviations from the column mean are set to
    NaN (skipped when the global DATASET is 'energy' or 'stock'), and the
    result is min-max scaled with the training min/max. The static tables
    are then passed through ``create_df_demo`` separately.

    Returns:
        (df_demo_train, df_demo_test, df_filt_train, df_filt_test,
        state_preprocess, demo_preprocess). ``state_preprocess`` holds the
        'mean', 'std', 'min' and 'max' of the training state variables;
        ``demo_preprocess`` holds the training statistics returned by
        ``create_df_demo`` plus the key 'demo_vars_enc'.
    """
    demo_in_train = df_demo['RecordID'].isin(train_ids)
    ts_in_train = df_filt['RecordID'].isin(train_ids)

    df_demo_train = df_demo.loc[demo_in_train].copy()
    df_filt_train = df_filt.loc[ts_in_train].copy()
    df_demo_test = df_demo.loc[~demo_in_train].copy()
    df_filt_test = df_filt.loc[~ts_in_train].copy()

    # Without held-out records, evaluate on the training records themselves.
    if len(df_demo_test) == 0:
        df_demo_test = df_demo_train.copy()
        df_filt_test = df_filt_train.copy()
    print('1',df_filt_train.isnull().sum().sum())

    # --- Step 1: z-score with training statistics --------------------------
    print('Step 1: Normalization ',df_filt_train.isnull().sum().sum())
    train_mean = df_filt_train[state_vars].mean()
    train_std = df_filt_train[state_vars].std()
    state_preprocess = {'mean': train_mean, 'std': train_std}

    df_filt_train[state_vars] = (df_filt_train[state_vars] - train_mean) / train_std
    df_filt_test[state_vars] = (df_filt_test[state_vars] - train_mean) / train_std

    # --- Step 2: blank out values beyond 3 standard deviations -------------
    print('Step 2: set outliers to nan ',df_filt_train.isnull().sum().sum())
    if DATASET not in ['energy','stock']:

        def _mask_beyond_three_sigma(col):
            # The mean/std come from the column being masked.
            # NOTE(review): the test frame is therefore masked with its own
            # statistics rather than the training ones — confirm intended.
            z_score = (col - col.mean()) / col.std()
            return col.mask(z_score.abs() > 3)

        df_filt_train[state_vars] = df_filt_train[state_vars].apply(_mask_beyond_three_sigma)
        df_filt_test[state_vars] = df_filt_test[state_vars].apply(_mask_beyond_three_sigma)

    # --- Step 3: min-max scaling with training extrema ---------------------
    print('Step 3: min-max normalization ')
    train_min = df_filt_train[state_vars].min()
    train_max = df_filt_train[state_vars].max()
    state_preprocess['min'] = train_min
    state_preprocess['max'] = train_max

    df_filt_train[state_vars] = (df_filt_train[state_vars] - train_min) / (train_max - train_min)
    df_filt_test[state_vars] = (df_filt_test[state_vars] - train_min) / (train_max - train_min)

    # --- Static variables ---------------------------------------------------
    # NOTE(review): the test table is encoded by its own call below, so it is
    # standardised with test-set statistics — confirm this is intended.
    print('Step 5: Normalize demo variables ')
    df_demo_train, demo_dict, demo_preprocess = create_df_demo(df_demo_train, conditional_vars)
    demo_preprocess['demo_vars_enc'] = list(demo_dict.keys())

    df_demo_test, _, _ = create_df_demo(df_demo_test, conditional_vars)

    return df_demo_train, df_demo_test, df_filt_train, df_filt_test, state_preprocess, demo_preprocess

In [None]:
def handle_cgan(df_demo, df_filt, state_vars,demo_vars_enc,granularity=1,IMG_SIZE=64):
    """Convert per-record time series into fixed-size value/mask images.

    Each record's rows are placed on a shared time grid of ``IMG_SIZE``
    points (one image row per grid point, one image column per entry of
    ``state_vars``). The variable axis is then zero-padded on the right up
    to ``IMG_SIZE`` columns and a channel axis is added.

    Observations whose time stamp is not on the grid are dropped. The
    earlier outer merge appended them as extra rows, which gave records
    different row counts and made ``torch.stack`` fail.

    Args:
        df_demo: static table; only the columns in ``demo_vars_enc`` are read.
        df_filt: long-format table with 'RecordID', 'Time' and ``state_vars``.
        state_vars: ordered list of time-series column names.
        demo_vars_enc: ordered list of static column names to export.
        granularity: enters the end point of the time grid.
        IMG_SIZE: number of grid points and padded width of the variable axis.

    Returns:
        (masks, values, static, time_mark):
          * masks  - float, (n_records, 1, IMG_SIZE, IMG_SIZE); 1 where a
            value was observed, -1 elsewhere (including padding);
          * values - float, same shape; observed values mapped by x*2-1,
            0 wherever the mask is -1;
          * static - float, (n_records, len(demo_vars_enc));
          * time_mark - int, (n_records, IMG_SIZE); 1 for grid points not
            later than the record's last time stamp.

    Raises:
        ValueError: if ``df_filt`` contains no records.
    """
    print(demo_vars_enc)
    sta = df_demo[demo_vars_enc].rename(columns={'In-hospital_death':'Label'})
    all_sta = torch.from_numpy(sta.values)
    print(all_sta.shape)

    # The grid is identical for every record, so build it once.
    # NOTE(review): for granularity == 1 this is 0, 1, ..., IMG_SIZE-1; for any
    # other value the spacing is not `granularity` — confirm before relying on it.
    time_padded = np.linspace(0, int(IMG_SIZE/granularity)-granularity, IMG_SIZE)  # shape (IMG_SIZE,)
    df_time_padded = pd.DataFrame({'time': time_padded})

    all_dyn = []
    all_masks = []
    all_time_mark = []
    for _, group in df_filt.groupby('RecordID'):
        sample = group[['Time'] + state_vars].rename(columns={'Time': 'time'})

        # Grid points up to (and including) the record's last observation.
        time_mark = time_padded <= sample['time'].max()

        # Left merge keeps exactly the IMG_SIZE grid rows, in grid order.
        on_grid = df_time_padded.merge(sample, how='left', on='time')
        grid_values = on_grid[state_vars]

        all_dyn.append(torch.from_numpy(grid_values.fillna(0).values))
        all_masks.append(torch.from_numpy(grid_values.notnull().astype(int).values))
        all_time_mark.append(torch.from_numpy(time_mark))

    if not all_masks:
        raise ValueError("df_filt contains no records to convert")

    all_masks = torch.stack(all_masks, dim=0)
    all_dyn = torch.stack(all_dyn, dim=0)
    all_time_mark = torch.stack(all_time_mark)

    # Pad the variable axis (last dim) on the right up to IMG_SIZE columns.
    # NOTE(review): a negative amount (more variables than IMG_SIZE) makes
    # torch crop the trailing variables instead of padding.
    padding_needed = IMG_SIZE - all_masks.shape[-1]
    all_masks_padded = torch.nn.functional.pad(all_masks, (0, padding_needed)).unsqueeze(1).float()  # add channel dim
    all_dyn_padded = torch.nn.functional.pad(all_dyn, (0, padding_needed)).unsqueeze(1).float()  # add channel dim

    # Scale to [-1, 1]; unobserved pixels (mask == -1) carry the value 0.
    all_masks_padded = all_masks_padded*2-1
    all_dyn_padded = all_dyn_padded*2-1
    all_dyn_padded[all_masks_padded<0]=0

    return all_masks_padded.float(),all_dyn_padded.float(),all_sta.float(), all_time_mark.int()

# Processed Format

In [None]:

# Load the raw table for the selected dataset and declare the role of each column.
# Every branch defines: df, all_cols, cols_id, cols_outcome, cols_demo, cols_vital, cols_lab.
if DATASET=='p12':
    print('### Handling P12 dataset')
    df = pd.read_csv(PATH_RAW+f'/{DATASET}.csv', index_col=None)

    print('# Number of patients:',df.RecordID.nunique())

    # convert time from an 'HH:MM' string to hours
    df['Time'] = df['Time'].apply(lambda x:       int(x.split(':')[0]) + int(x.split(':')[1])/60      )

    # rename column
    df.rename(columns={'In-hospital_death':'Label'}, inplace=True)

    print('# Prevelence of in-hospital mortality: ', df.Label.mean())

    # columns handling
    all_cols = df.columns

    cols_id = ['RecordID','Time']
    cols_outcome = ['Label']
    cols_demo = ['Age','Gender','Height','ICUType','Weight']
    cols_vital = ['HR','NIDiasABP', 'NIMAP', 'NISysABP','RespRate', 'Temp','DiasABP','MAP','SysABP', 'GCS']
    cols_lab = [ 'BUN', 'Creatinine', 'Glucose', 'HCO3', 'HCT', 'K', 'Mg','Na', 'Platelets','Urine', 'WBC','FiO2', 'PaCO2', 'PaO2', 'SaO2', 'pH', 'ALP', 'ALT',
        'AST', 'Albumin', 'Bilirubin', 'Lactate', 'Cholesterol', 'TroponinI',
        'TroponinT']
    # cols_ignore = ['MechVent','Hospital'] # will be ignored

    print('# Distribution of LOS:')
    print(df.groupby('RecordID')['Time'].max().describe())

elif DATASET in ['p19']:
    print('### Handling P19 dataset')

    # data from hospital A and B
    df_A = pd.read_csv(PATH_RAW+'/df_A.csv')
    df_A['Hospital']=0

    df_B = pd.read_csv(PATH_RAW+'/df_B.csv')
    df_B['Hospital']=1

    # combine two datasets into one; ids of hospital B are offset by the
    # number of distinct ids in hospital A
    df_B['id']=df_B['id']+len(df_A['id'].unique())
    # ignore_index gives the combined table a unique row index, so the
    # index-aligned assignments below cannot hit duplicate labels
    df=pd.concat([df_A,df_B], ignore_index=True)

    print('# Number of patients:',df.id.nunique())

    pos_ids = df[df.SepsisLabel==1].id.unique()

    print('# Prevelence of sepsis:',len(pos_ids)/df.id.nunique())
    # IMPORTANT
    # we shift the label 6 hours before. This is because the label at each row indicates sepsis for the next 6 hours
    df['SepsisLabel'] = df.groupby('id')['SepsisLabel'].shift(6, fill_value=0).astype(int)

    # we remove the rows with label 1. Our goal is to predict sepsis patients, not the detection of onset of sepsis
    # (.copy() so the .loc assignments below write to an independent frame)
    df = df[df.SepsisLabel!=1].copy()

    # if patient is septic, we set the label to 1 across all rows
    df.loc[df.id.isin(pos_ids),'SepsisLabel'] = 1
    df.loc[~df.id.isin(pos_ids),'SepsisLabel'] = 0

    # rename some columns
    df.rename(columns={'id':'RecordID','ICULOS':'Time','SepsisLabel':'Label'}, inplace=True)

    print('# Distribution of LOS: brefore removing patients with LOS>64')
    print('# df.shape:',df.shape)
    print(df.groupby('RecordID')['Time'].max().describe())

    # we only keep rows with Time <= 63
    df = df[df.Time<=63]

    print('# Distribution of LOS: after removing patients with LOS>64')
    print('# df.shape:',df.shape)
    print(df.groupby('RecordID')['Time'].max().describe())

    # columns handling
    all_cols = df.columns

    cols_id = ['RecordID','Time']
    cols_outcome = ['Label']
    cols_demo = ['Age','Gender','HospAdmTime']
    cols_vital = [ 'HR', 'O2Sat', 'Temp', 'SBP', 'MAP', 'DBP', 'Resp']
    cols_lab = [
        'BUN','Creatinine', 'Glucose',  'HCO3', 'Potassium','Magnesium',
        'Hct','Platelets','WBC',
        'FiO2', 'PaCO2','pH','SaO2',
         'Alkalinephos','AST', 'Bilirubin_total','Bilirubin_direct',
         'Lactate',  'TroponinI',
        'Hgb', 'Chloride',
          'Phosphate', 'Calcium',    'PTT',   'Fibrinogen' ]
    # cols_ignore = ['EtCO2','BaseExcess','Unit1', 'Unit2','Hospital'] # will be ignored

elif DATASET in ['mimic']:

    df = pd.read_csv(PATH_RAW+'/df_mimic.csv', index_col=None)

    # columns handling
    all_cols = df.columns

    cols_id = ['RecordID','Time']
    cols_outcome = ['Label']
    cols_demo = ['Age','Height','Weight']
    cols_vital = ['heartrate', 'sysbp', 'diasbp', 'meanbp',
        'resprate', 'tempc', 'spo2']
    cols_lab = [ 'glucose_chart',  'endotrachflag', 'bg_so2', 'bg_po2', 'bg_pco2',
        'bg_pao2fio2ratio', 'bg_ph', 'bg_baseexcess', 'bg_bicarbonate',
        'bg_totalco2', 'bg_hematocrit', 'bg_hemoglobin', 'bg_carboxyhemoglobin',
        'bg_methemoglobin', 'bg_chloride', 'bg_calcium', 'bg_temperature',
        'bg_potassium', 'bg_sodium', 'bg_lactate', 'bg_glucose', 'aniongap',
        'albumin', 'bands', 'bicarbonate', 'bilirubin', 'creatinine',
        'chloride', 'glucose', 'hematocrit', 'hemoglobin', 'lactate',
        'platelet', 'potassium', 'ptt', 'inr', 'pt', 'sodium', 'bun', 'wbc',
        'urineoutput']
    # cols_ignore = []

elif 'sim' in DATASET:

    df = pd.read_csv(PATH_RAW+f'/{DATASET}.csv')

    # The sequence length is the first of 16/32/64/128 that appears in the
    # dataset name. Fail loudly instead of hitting a NameError further down.
    seq_len = next((n for n in (16, 32, 64, 128) if str(n) in DATASET), None)
    if seq_len is None:
        raise ValueError(f"Cannot infer the sequence length from DATASET={DATASET!r}; expected 16, 32, 64 or 128 in the name")

    # Cut the table into disjoint, consecutive windows of seq_len rows and
    # drop the incomplete tail. (.copy() so the new columns are written to
    # an independent frame.)
    n_records = len(df)//seq_len
    df = df.iloc[:n_records*seq_len].copy()

    # Every column of the raw file is a time-series channel.
    var_names = df.columns.tolist()

    df['Time'] =  np.tile(np.arange(seq_len),n_records)
    df['RecordID'] = np.repeat(np.arange(n_records),seq_len)

    df['Hospital'] = 0
    # Age is synthetic: random integers in [20, 80). A dedicated seeded
    # generator keeps it reproducible across kernel restarts.
    df['Age'] = np.random.default_rng(42).integers(20,80,len(df))
    df['Label'] = 0
    all_cols = df.columns

    cols_id = ['RecordID','Hospital','Time']
    cols_outcome = ['Label']
    cols_demo = ['Age']
    cols_vital = var_names
    cols_lab = []
    cols_ignore = []

else:
    raise ValueError(f"Unknown DATASET {DATASET!r}; no loading branch is defined for it")


In [None]:
# Columns of the raw table that were not given a role (id, outcome, demographic,
# vital or lab). They are listed in the table's column order so that the printed
# output is the same on every run; converting a set to a list gave an arbitrary order.
assigned_cols = set(cols_id) | set(cols_outcome) | set(cols_demo) | set(cols_vital) | set(cols_lab)
remaining_cols = [col for col in dict.fromkeys(all_cols) if col not in assigned_cols]


print("remaining_cols: ", remaining_cols)


In [None]:
# Time-series channels: vitals first, then labs (a new list, so shuffling it
# leaves cols_vital and cols_lab untouched).
state_vars = [*cols_vital, *cols_lab]

# Optionally permute the channel order, with a fixed seed.
if SUFFLE_VARS:
    np.random.seed(42)
    np.random.shuffle(state_vars)


demo_vars = cols_demo

# Channel name -> position in state_vars.
dict_map_states = dict(zip(state_vars, range(len(state_vars))))

# Demographic name -> position in cols_demo.
dict_map_demos = dict(zip(cols_demo, range(len(cols_demo))))
dict_map_demos


print("number of state variables (vital+lab): ", len(state_vars))
print("number of demographic variables: ", len(demo_vars))

In [None]:
# Split df into a long time-series table (df_ts) and a one-row-per-record
# static table (df_static).

for id_like_col in ('RecordID', 'Label'):
    df[id_like_col] = df[id_like_col].astype(int)

# Time series: identifiers plus every vital and lab column.
df_ts = df.loc[:, ['RecordID', 'Time', *cols_vital, *cols_lab]].copy()

# Static: demographics and outcome, keeping the first row seen for each record.
df_static = (
    df.loc[:, ['RecordID', *cols_demo, *cols_outcome]]
    .drop_duplicates(subset='RecordID')
    .copy()
)

df_ts.shape, df_static.shape

df_ts.head()
df_static.head()

In [None]:
# check for missing rate

# Per-column missing rate in percent: missing cells of a column divided by the
# number of rows. (Dividing by rows*columns understated every per-column rate.)
df_ts.isnull().mean()*100

df_static.isnull().mean()*100

# Overall rate: missing cells over all cells of df_ts, identifier columns included.
print("overall missingness rate(%): ", df_ts.isnull().sum().sum()/(df_ts.shape[0]*df_ts.shape[1])*100)

# bar plot of missing rate for each variable using plotly

missing_rate = df_ts.isnull().mean().sort_values(ascending=False)
missing_rate = missing_rate[missing_rate>0]

fig = px.bar(x=missing_rate.index, y=missing_rate.values*100, labels={'x':'Variable','y':'Missing rate (%)'}, title='Missing rate for each variable')
fig.show()

In [None]:
# choose time granularity
# Reuse the GRAN constant from the header cell (hours per time step) instead of
# repeating the literal, so both settings cannot drift apart.
time_granularity = GRAN

df_ts.iloc[:30].Time.values
# Snap every time stamp to the nearest multiple of the granularity (vectorised).
df_ts['Time'] = (df_ts['Time'] / time_granularity).round() * time_granularity
df_ts.iloc[:30].Time.values
df_ts.shape
df_ts.head(10)

In [None]:
# Remove time points at which no vital and no lab value was recorded.
df_ts.shape

df_ts = df_ts.dropna(how='all', subset=[*cols_vital, *cols_lab])

df_ts.shape

In [None]:
# First ten rows of the time-series table.
df_ts.head(10)

# Summary statistics for every column of df_ts.
df_ts.describe()

# You can see that in P12 we might have multiple measurements for the same time point.


## Aggregate

If a record has multiple observations at the same time point, forward-fill within that time point and keep the last row.

It is only used for P12 dataset. Time granularity for other datasets is already 1 hour.

In [None]:
# is aggregation needed?
# a1: rows per record, a2: distinct time stamps per record. They differ as soon
# as any record has more than one row for the same time stamp.

a1 = df_ts.groupby('RecordID').size().values
a2 = df_ts.groupby('RecordID')['Time'].nunique().values

if (a1-a2).sum()>0:
    print('df.shape before aggregation:',df_ts.shape)
    # forward-fill within each (RecordID, Time) group
    # (GroupBy.ffill replaces the deprecated fillna(method='ffill'))
    df_ts[cols_lab+cols_vital] = df_ts[['RecordID','Time']+cols_lab+cols_vital].groupby(['RecordID','Time']).ffill()

    # keep last for each group
    df_ts = df_ts[['RecordID','Time']+cols_lab+cols_vital].groupby(['RecordID','Time']).last().reset_index()
    print('df.shape after aggregation:',df_ts.shape)
else:
    print('no aggregation needed')

In [None]:
df_static.shape

# forward-fill within each record
# (GroupBy.ffill replaces the deprecated fillna(method='ffill'))
df_static[cols_demo+cols_outcome] = df_static[['RecordID']+cols_demo+cols_outcome].groupby(['RecordID']).ffill()

# keep last for each group, leaving one row per RecordID
df_static = df_static[['RecordID']+cols_demo+cols_outcome].groupby(['RecordID']).last().reset_index()

df_static.shape

In [None]:
# Number of distinct records in the time-series table and in the static table.
df_ts.RecordID.nunique(), df_static.RecordID.nunique()

In [None]:
# Summary statistics of the Time column across all rows of df_ts.
df_ts.Time.describe()

## K-fold

This writes `df_ts`, `df_static`, and the list of training IDs for each split.

In [None]:
# The number of folds is N_SPLIT from the header cell. It is no longer reset
# here, and the fold boundaries below are derived from it instead of a literal 5.

# seed for numpy and dataframe sampling
np.random.seed(42)


# shuffle ids
list_ids = df_static['RecordID'].values.copy()
np.random.shuffle(list_ids)
list_ids[:5]

df_static = df_static[df_static['RecordID'].isin(list_ids)]
df_static.head()


# N_SPLIT+1 boundaries that cut the shuffled ids into N_SPLIT contiguous folds
split_list = np.linspace(0,len(df_static),N_SPLIT+1).astype(int)
split_list


for i_split in range(N_SPLIT):

    path2save = PATH_PROCESSED+f'/split{i_split}'
    os.makedirs(path2save , exist_ok=True)

    # fold i is held out; every other id is used for training
    test_ids = list_ids[split_list[i_split]:split_list[i_split+1]]
    train_ids = list(set(list_ids)-set(test_ids))

    # save train ids (the handle is not named `f`, which holds the config dict)
    with open(path2save+'/train_ids.pkl', 'wb') as ids_file:
        pickle.dump(train_ids, ids_file)

    # save df_static and df_ts (the full tables are written into every split folder)
    df_static.to_csv(path2save+'/df_static.csv',index=False)
    df_ts.to_csv(path2save+'/df_ts.csv',index=False)

    print('saved to', path2save)



# 163575645

# TimEHR format

We further process `df_ts` and `df_static` into images for TimEHR.

In [None]:
# Dataset class

class Physio3(Dataset):
    """Torch dataset that serves (mask, value, static) triples by index.

    Args:
        all_masks: tensor whose first dimension indexes the samples.
        all_values: tensor indexed in parallel with ``all_masks``.
        all_sta: static features indexed in parallel with ``all_masks``.
        static_processor: stored as given on ``self.static_processor``.
        dynamic_processor: stored as given on ``self.dynamic_processor``;
            when provided it must have a 'mean' entry, whose length becomes
            ``self.n_ts``.
        transform: stored as given; not applied by ``__getitem__``.
        ids: stored as given on ``self.ids``.
        max_len: stored as given on ``self.max_len``.
    """

    def __init__(self, all_masks,all_values,all_sta,static_processor=None, dynamic_processor=None, transform=None, ids=None, max_len=None):
        self.num_samples = all_masks.shape[0]
        self.mask = all_masks
        self.value = all_values
        self.sta = all_sta
        self.transform = transform

        self.static_processor = static_processor
        self.dynamic_processor = dynamic_processor
        self.ids = ids
        self.max_len = max_len

        # Number of time-series variables. The default dynamic_processor=None
        # raised a TypeError here; n_ts is now None in that case.
        if dynamic_processor is None:
            self.n_ts = None
        else:
            self.n_ts = len(dynamic_processor['mean'])

    def __len__(self):
        """Return the number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the (mask, value, static) triple stored at ``idx``."""
        return self.mask[idx], self.value[idx], self.sta[idx]

In [None]:

# Variables used for the conditional GAN framework: the demographic columns
# followed by the outcome column(s).
conditional_vars = [*cols_demo, *cols_outcome]


In [None]:


# For every split: reload the saved tables, normalise them with statistics
# returned by custom_process, convert them to images and pickle one Physio3
# dataset for training (train.pkl) and one for evaluation (eval.pkl).
for i_split in range(N_SPLIT):

    path2save = PATH_PROCESSED+f'/split{i_split}'

    print("path2save:", path2save)

    # loading processed data
    df_static = pd.read_csv(path2save+'/df_static.csv', index_col=None)
    df_ts = pd.read_csv(path2save+'/df_ts.csv', index_col=None)

    # keep only records present in both tables (inner joins on RecordID)
    df_ts = df_ts.merge(df_static[['RecordID','Label']],on=['RecordID'],how='inner')
    df_static = df_static.merge(df_ts[['RecordID']].drop_duplicates(),on=['RecordID'],how='inner')

    # sort both
    df_static = df_static.sort_values(by=['RecordID'])
    df_ts = df_ts.sort_values(by=['RecordID','Time'])

    # load train_ids (the handle is not named `f`, which holds the config dict)
    with open(path2save+'/train_ids.pkl', 'rb') as ids_file:
        train_ids = pickle.load(ids_file)

    # No dev split is stored; -5555 is a placeholder id, and rows with that
    # RecordID are excluded from both tables.
    dev_ids=[-5555]
    df_static = df_static[~df_static['RecordID'].isin(dev_ids)].copy()
    df_ts = df_ts[~df_ts['RecordID'].isin(dev_ids)].copy()

    df_static_train, df_static_test, df_ts_train, df_ts_test,   state_preprocess, demo_preprocess  = custom_process(df_static, df_ts, train_ids, state_vars, conditional_vars)

    demo_vars_enc = demo_preprocess['demo_vars_enc']

    # Same export for both parts; only the inputs and the file name differ.
    for part_name, df_static_part, df_ts_part in (
        ('train', df_static_train, df_ts_train),
        ('eval', df_static_test, df_ts_test),
    ):
        all_masks,all_values,all_sta,all_time_mark = handle_cgan(df_static_part, df_ts_part,state_vars,demo_vars_enc,granularity=GRAN,IMG_SIZE=IMG_SIZE)

        ph = Physio3(all_masks,all_values,all_sta,static_processor=demo_preprocess, dynamic_processor=state_preprocess, ids=df_static_part['RecordID'].values, max_len=all_time_mark)

        with open(path2save+f"/{part_name}.pkl", 'wb') as file:
            pickle.dump(ph, file)

    # cgan3 now pixels are from -1 to 1
