# Comprehensive Exploratory Data Analysis: Step by Step

<a name='1'></a>
## 1 - Packages

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import h5py

import os
# import wandb
import warnings

from config import *
from utils.eda_functions import *
from src.data_preprocessing.demographics import *
from src.data_preprocessing.vitals_labs import *
from src.data_preprocessing.split_dataset import *
from src.data_imputation.simple_impute import *
from src.data_imputation.hybrid_impute import hybrid_imputer
from utils.safe_display import blind_display

%load_ext autoreload
%autoreload 2

In [2]:
# Configuration & Settings:

# plt.style.use('seaborn-whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

warnings.filterwarnings('ignore')
# os.environ['WANDB_SILENT'] = 'true'
# sns.set_theme(style="whitegrid")
# 
# # Login and initialize a new wandb run
# wandb_key = os.environ.get("WANDB_API_KEY")
# ! wandb login $wandb_key
#
# run = wandb.init(
#     project='FuzzyMedNet',
#     name='patient_eda',
#     job_type='eda'
# )

In [3]:
# Safety flag - when set to True will not display sensitive data
BLINDED = False

<a name='2'></a>
## 2 - Data Overview

<a name='2-1'></a>
### 2.1 - Loading the Data

In [4]:
# Open the file in read mode
with h5py.File(DATA_FILE_PATH, 'r') as file:
    # Print the keys at the root of the file
    print(list(file.keys()))

patients = pd.read_hdf(DATA_FILE_PATH, 'patients')
vitals_labs = pd.read_hdf(DATA_FILE_PATH, 'vitals_labs')
vitals_labs_mean = pd.read_hdf(DATA_FILE_PATH, 'vitals_labs_mean')
interventions = pd.read_hdf(DATA_FILE_PATH, 'interventions')
codes = pd.read_hdf(DATA_FILE_PATH, 'codes')

['codes', 'interventions', 'patients', 'vitals_labs', 'vitals_labs_mean']


<a name='2-2'></a>
### 2.2 - Basic Information

In [5]:
# print(f'patients.shape: {patients.shape}')
# print(patients.info())

In [6]:
# print(f'vitals_labs_mean.shape: {vitals_labs_mean.shape}')
# print(vitals_labs_mean.info())

In [7]:
# print(f'interventions.shape: {interventions.shape}')
# print(interventions.info())

In [8]:
# # Limit vital signs, lab measurements, and interventions to the first 30 hours of admission
# vitals_labs = vitals_labs[vitals_labs.index.get_level_values('hours_in') < 30]
# vitals_labs_mean = vitals_labs_mean[vitals_labs_mean.index.get_level_values('hours_in') < 30]
# interventions = interventions[interventions.index.get_level_values('hours_in') < 30]
# blind_display(patients, vitals_labs_mean, interventions, blinded=BLINDED)

<a name='3'></a>
## 3 - Univariate Analysis

<a name='3-1'></a>
### 3.1 - Categorical Variables

<a name='3-1-1'></a>
`patients` data

<a name='3-1-1-1'></a>
#### Transforming Static Features

In [9]:
# group age
patients['age'] = patients['age'].apply(categorize_age)
patients['age'].value_counts(dropna=False)

age
>70      14213
51-70    12938
31-50     5489
<31       1832
Name: count, dtype: int64

In [10]:
# check ethnicity categories
patients['ethnicity'].value_counts(dropna=False)

ethnicity
WHITE                                                       24429
UNKNOWN/NOT SPECIFIED                                        3221
BLACK/AFRICAN AMERICAN                                       2456
HISPANIC OR LATINO                                            881
OTHER                                                         785
UNABLE TO OBTAIN                                              652
ASIAN                                                         545
PATIENT DECLINED TO ANSWER                                    351
ASIAN - CHINESE                                               166
HISPANIC/LATINO - PUERTO RICAN                                124
BLACK/CAPE VERDEAN                                            122
WHITE - RUSSIAN                                                99
MULTI RACE ETHNICITY                                           77
BLACK/HAITIAN                                                  64
WHITE - OTHER EUROPEAN                                         59


In [11]:
# Regroup similar categories of ethnicities
patients['ethnicity'] = patients['ethnicity'].apply(categorize_ethnicity)
patients['ethnicity'].value_counts(dropna=False)

ethnicity
WHITE              24675
OTHER/UNKNOWN       5086
BLACK               2667
HISPANIC            1137
ASIAN                865
ISLANDER              25
NATIVE AMERICAN       17
Name: count, dtype: int64

In [12]:
# regroup admission types into: EMERGENCY & ELECTIVE
patients['admission_type'] = patients['admission_type'].apply(group_admission_type)
patients['admission_type'].value_counts(dropna=False)

admission_type
EMERGENCY    28767
ELECTIVE      5705
Name: count, dtype: int64

<a name='3-1-2'></a>
`interventions` data

<a name='3-2'></a>
### 3.2 - Continuous Variables

<a name='3-2-1'></a>
`patients` data

In [13]:
# # Calculate length of stay in ICU
# patients['icu_stay_length'] = calculate_duration(patients, 'intime', 'outtime', 'h')  # in hours
#
# # Plot distribution of times
# time_columns = ['admittime', 'dischtime', 'intime', 'outtime']
# plot_time_analysis(patients, time_columns)

<a name='3-2-2'></a>
`vitals_labs` data
1. **Descriptive Statistics**
We'll start by calculating summary statistics for the DataFrame. This will help us understand the central tendency, spread, and shape of the distribution of the dataset.
<br></br>
2. **Time-Series Plots**
For selected vital signs and lab measurements, we'll plot time-series graphs.
<br></br>
3. **Distribution Plots**
We'll visualize the distribution of selected columns to understand their shape, center, and spread.

#### Hourly Measurements Spread across first 30 hours - Vital Signs

In [14]:
# plot_hourly_counts(df=hourly_vitals_df, features=VITAL_SIGNS)

#### Hourly Measurements Spread across first 30 hours - Labs

In [15]:
# plot_hourly_counts(df=hourly_vitals_df, features=lab_features)

<a name='4'></a>
## 4 - Data Pre-processing

<a name='4-1'></a>
### 4.1 - Identifying Missing Data

In the preceding examination, it was identified that missing data exists solely within the `vitals_labs` dataframe.

<a name='4-2'></a>
### 4.2 - Train/Dev/Test Split

When imputing missing values, it's a best practice to calculate the mean (or any other statistic you're using for imputation) from the training set only. Then, use this calculated mean to impute missing values in both the training set and the dev/test sets.
This approach helps in preventing data leakage from the dev/test set into the training set.

The function `train_test_dev_split` performs the data splitting operation for the project. It takes three dataframes (`patients`, `vitals_labs`, and `interventions`) and returns a dictionary containing the train, dev, and test splits for each dataframe.

Steps:
**Filter Targets**: Only consider patients with sufficient data for the target variables `mort_hosp` and `mort_icu`.

**Extract Static Features**: Remove the target variables and other non-static features from the `patients` dataframe.

**Filter Time-Series Data**: Trim the time-series data to only include instances within a specified observation window.

**Subject ID Validation**: Check that the subject IDs are consistent across all dataframes.

**Random Split**: Randomly shuffle the subject IDs and allocate them into train, dev, and test sets based on predefined fractions.

**Dataframe Splits**: Create train, dev, and test dataframes for each of the input dataframes (`patients`, `vitals_labs`, `interventions`), based on the shuffled subject IDs.

**Return Data**: A dictionary containing the train, dev, and test splits for each dataframe is returned.

### Encoding full `patient` categorical data before splitting

In [16]:
# Drop irrelevant or duplicated columns
columns_to_drop = ['admittime', 'dischtime', 'outtime', 'deathtime', 'discharge_location', 'dnr_first_charttime', 'diagnosis_at_admission', 'insurance', 'hospstay_seq', 'hospital_expire_flag', 'los_icu']

patients.drop(columns= columns_to_drop, inplace=True)

# Select columns to be one-hot encoded
categorical_cols = ['gender', 'age', 'ethnicity', 'admission_type', 'first_careunit']

patients_encoded = pd.get_dummies(patients, columns=categorical_cols)
patients_encoded.drop(columns=['gender_F', 'admission_type_ELECTIVE'], inplace=True)
patients_encoded.fillna(0, inplace=True)

patients_encoded.loc[:, 'intime'] = patients_encoded['intime'].astype('datetime64[ns]').apply(lambda x : x.hour)

blind_display(patients_encoded, blinded=BLINDED)

patients_encoded.shape:  (34472, 30)


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,fullcode_first,dnr_first,fullcode,dnr,cmo_first,cmo_last,cmo,intime,mort_icu,mort_hosp,...,ethnicity_ISLANDER,ethnicity_NATIVE AMERICAN,ethnicity_OTHER/UNKNOWN,ethnicity_WHITE,admission_type_EMERGENCY,first_careunit_CCU,first_careunit_CSRU,first_careunit_MICU,first_careunit_SICU,first_careunit_TSICU
subject_id,hadm_id,icustay_id,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
3,145834,211552,1.0,0.0,1.0,1.0,0.0,0.0,0.0,19,0,0,...,False,False,False,True,True,False,False,True,False,False
4,185777,294638,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0,0,...,False,False,False,True,True,False,False,True,False,False
6,107064,228232,1.0,0.0,1.0,0.0,0.0,0.0,0.0,21,0,0,...,False,False,False,True,False,False,False,False,True,False
9,150750,220597,1.0,0.0,1.0,0.0,0.0,0.0,0.0,13,1,1,...,False,False,True,False,True,False,False,True,False,False
11,194540,229441,1.0,0.0,1.0,0.0,0.0,0.0,0.0,6,0,0,...,False,False,False,True,True,False,False,False,True,False


In [17]:
datasets = train_test_dev_split(patients_encoded, vitals_labs, interventions)

In [18]:
# Define keys and corresponding variable names
keys_varnames = [
    ('patients', 'patients_train', 'patients_dev', 'patients_test'),
    ('vitals', 'vitals_train', 'vitals_dev', 'vitals_test'),
    ('interv', 'interv_train', 'interv_dev', 'interv_test'),
    ('Ys', 'Ys_train', 'Ys_dev', 'Ys_test')
]

# Loop through each key and variable name to extract data and print shapes
for key, train_var, dev_var, test_var in keys_varnames:
    train_data, dev_data, test_data = datasets[key]
    print(f'\n{key.capitalize()}:')
    print(train_data.shape)
    print(dev_data.shape)
    print(test_data.shape)

    # log the datasets
    # save_to_pickle(train_data, os.path.join(LOG_DATA_DIR, f'{train_var}_gru_split.pkl'))
    # save_to_pickle(dev_data, os.path.join(LOG_DATA_DIR, f'{dev_var}_gru_split.pkl'))
    # save_to_pickle(test_data, os.path.join(LOG_DATA_DIR, f'{test_var}_gru_split.pkl'))

    # set the variables
    globals()[train_var], globals()[dev_var], globals()[test_var] = train_data, dev_data, test_data


Patients:
(11543, 27)
(1650, 27)
(3299, 27)

Vitals:
(554064, 312)
(79200, 312)
(158352, 312)

Interv:
(554064, 14)
(79200, 14)
(158352, 14)

Ys:
(11543, 2)
(1650, 2)
(3299, 2)


<a name='4-3'></a>
### 4.3 Data Imputation
The strategy for handling the missing values in `vitals_labs` is as follows:

- Initial imputation employs **forward-filling** to propagate the last valid observation to succeeding `NaN` entries.

- Subsequently, any residual missing values within each `icustay_id` group are replaced by the **group's mean**.

- Finally, any remaining `NaN`s are filled with **zeros**.

- A binary mask is generated to indicate the presence of data.

- Time deltas are computed to represent the duration since the last available measurement for each column.

#### Simple Impute

In [30]:
from src.data_imputation.simple_impute import simple_imputer

In [31]:
%%time
global_means, icustay_means = calculate_impute_values(vitals_train)

vitals_train_imputed, vitals_dev_imputed, vitals_test_imputed = [
    simple_imputer(df, global_means, icustay_means) for df in (vitals_train, vitals_dev, vitals_test)
]
# vitals_train_flat, vitals_dev_flat, vitals_test_flat = [
#     df.pivot_table(index=['subject_id', 'hadm_id', 'icustay_id'], columns=['hours_in']) for df in (
#         vitals_train_imputed, vitals_dev_imputed, vitals_test_imputed
#     )
# ]

# # log the datasets
# save_to_pickle(vitals_train_imputed, os.path.join(LOG_DATA_DIR, f'vitals_train_gru_imputed.pkl'))
# save_to_pickle(vitals_dev_imputed, os.path.join(LOG_DATA_DIR, f'vitals_dev_gru_imputed.pkl'))
# save_to_pickle(vitals_test_imputed, os.path.join(LOG_DATA_DIR, f'vitals_test_gru_imputed.pkl'))

for df in vitals_train_imputed, vitals_dev_imputed, vitals_test_imputed: assert not df.isnull().any().any()

AssertionError: 

In [32]:
vitals_train_imputed

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin ascites,albumin ascites,albumin pleural,...,venous pvo2,weight,weight,weight,white blood cell count,white blood cell count,white blood cell count,white blood cell count urine,white blood cell count urine,white blood cell count urine
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Aggregation Function,mask,mean,time_since_measured,mask,mean,time_since_measured,mask,mean,time_since_measured,mask,...,time_since_measured,mask,mean,time_since_measured,mask,mean,time_since_measured,mask,mean,time_since_measured
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
9,150750,220597,0,0.0,,100.0,0.0,,100.0,0.0,,100.0,0.0,...,100.0,0.0,,100.0,1.0,7.5,0.0,0.0,,100.0
9,150750,220597,1,0.0,,100.0,0.0,,100.0,0.0,,100.0,0.0,...,100.0,0.0,,100.0,0.0,,1.0,0.0,,100.0
9,150750,220597,2,0.0,,100.0,0.0,,100.0,0.0,,100.0,0.0,...,100.0,0.0,,100.0,0.0,,2.0,0.0,,100.0
9,150750,220597,3,0.0,,100.0,0.0,,100.0,0.0,,100.0,0.0,...,100.0,0.0,,100.0,0.0,,3.0,0.0,,100.0
9,150750,220597,4,0.0,,100.0,0.0,,100.0,0.0,,100.0,0.0,...,100.0,0.0,,100.0,0.0,,4.0,0.0,,100.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,137810,229633,43,0.0,,41.0,0.0,,267.0,0.0,,7821.0,0.0,...,214679.0,0.0,,4.0,0.0,,7.0,0.0,,11.0
99995,137810,229633,44,0.0,,42.0,0.0,,268.0,0.0,,7822.0,0.0,...,214680.0,1.0,67.857419,0.0,0.0,,8.0,0.0,,12.0
99995,137810,229633,45,0.0,,43.0,0.0,,269.0,0.0,,7823.0,0.0,...,214681.0,0.0,,1.0,0.0,,9.0,0.0,,13.0
99995,137810,229633,46,0.0,,44.0,0.0,,270.0,0.0,,7824.0,0.0,...,214682.0,0.0,,2.0,0.0,,10.0,0.0,,14.0


In [20]:
# vitals_train_imputed = load_from_pickle(os.path.join(DATA_DIR, 'vitals_train.pkl'))
# vitals_dev_imputed = load_from_pickle(os.path.join(DATA_DIR, 'vitals_dev.pkl'))
# vitals_test_imputed = load_from_pickle(os.path.join(DATA_DIR, 'vitals_test.pkl'))

#### Hybrid Impute

In [21]:
# %%time
# global_means, icustay_means = calculate_impute_values(vitals_train)
#
# vitals_train_imputed2, vitals_dev_imputed2, vitals_test_imputed2 = [
#     hybrid_imputer(df, global_means, icustay_means) for df in (vitals_train, vitals_dev, vitals_test)
# ]
# for df in vitals_train_imputed2, vitals_dev_imputed2, vitals_test_imputed2: assert not df.isnull().any().any()

In [22]:
# # Define keys and corresponding variable names
# vitals_train_imputed2.to_csv('../data/processed_hybrid/vitals_train_hybrid.csv')
# vitals_dev_imputed2.to_csv('../data/processed_hybrid/vitals_dev_hybrid.csv')
# vitals_test_imputed2.to_csv('../data/processed_hybrid/vitals_test_hybrid.csv')

#### Data Standardization

In [23]:
def standardize_gru(vitals_train, vitals_dev, vitals_test):
    idx = pd.IndexSlice
    X_train, X_dev, X_test = vitals_train.copy(), vitals_dev.copy(), vitals_test.copy()

    # Min-Max Scaling
    train_min = X_train.loc[:, idx[:, 'mean']].min()
    train_max = X_train.loc[:, idx[:, 'mean']].max()
    for df in [X_train, X_dev, X_test]:
        df.loc[:, idx[:, 'mean']] = minmax_scaling(df.loc[:, idx[:, 'mean']], train_min, train_max)

    # Standardization
    X_train.loc[:, idx[:, 'time_since_measured']] = np.where(X_train.loc[:, idx[:, 'time_since_measured']] == 100, 0,
                                                             X_train.loc[:, idx[:, 'time_since_measured']])
    train_mean = X_train.loc[:, idx[:, 'time_since_measured']].mean()
    train_std = X_train.loc[:, idx[:, 'time_since_measured']].std()
    for df in [X_train, X_dev, X_test]:
        df.loc[:, idx[:, 'time_since_measured']] = standardize_time_since_measured(
            df.loc[:, idx[:, 'time_since_measured']], train_mean, train_std)

    return X_train, X_dev, X_test

In [24]:
# vitals_train_std, vitals_dev_std, vitals_test_std = standardize_gru(vitals_train_imputed, vitals_dev_imputed, vitals_test_imputed)

In [25]:
# # log the datasets
# save_to_pickle(vitals_train_std, os.path.join(LOG_DATA_DIR, f'vitals_train_gru_std.pkl'))
# save_to_pickle(vitals_dev_std, os.path.join(LOG_DATA_DIR, f'vitals_dev_gru_std.pkl'))
# save_to_pickle(vitals_test_std, os.path.join(LOG_DATA_DIR, f'vitals_test_gru_std.pkl'))

In [26]:
highly_corr_drop = ['dnr', 'fullcode', 'cmo', 'cmo_last']
patients_train.drop(columns=highly_corr_drop, inplace=True)
patients_dev.drop(columns=highly_corr_drop, inplace=True)
patients_test.drop(columns=highly_corr_drop, inplace=True)

In [27]:
vitals_to_drop = ['alanine aminotransferase', 'co2', 'co2 (etco2, pco2, etc.)', 'blood urea nitrogen', 'cardiac output thermodilution', 'chloride', 'cholesterol ldl', 'hematocrit', 'red blood cell count', 'lactic acid', 'mean corpuscular volume', 'phosphorous', 'positive end-expiratory pressure', 'potassium serum', 'prothrombin time pt', 'tidal volume set']

In [28]:
vitals_train_std.drop(columns=vitals_to_drop, level=0, inplace=True)
vitals_dev_std.drop(columns=vitals_to_drop, level=0, inplace=True)
vitals_test_std.drop(columns=vitals_to_drop, level=0, inplace=True)

NameError: name 'vitals_train_std' is not defined

In [None]:
vitals_train_std

### Create Feature Matrix

In [None]:
# X_train_merged = create_feature_matrix(patients_train, vitals_train_std, interv_train)
# X_dev_merged = create_feature_matrix(patients_dev, vitals_dev_std, interv_dev)
# X_test_merged = create_feature_matrix(patients_test, vitals_test_std, interv_test)
# blind_display(X_train_merged, X_dev_merged, X_test_merged, blinded=BLINDED)

### Store Processed Data

In [29]:
vitals_train

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,LEVEL2,alanine aminotransferase,alanine aminotransferase,alanine aminotransferase,albumin,albumin,albumin,albumin ascites,albumin ascites,albumin ascites,albumin pleural,...,white blood cell count,white blood cell count urine,white blood cell count urine,white blood cell count urine,ph,ph,ph,ph urine,ph urine,ph urine
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Aggregation Function,count,mean,std,count,mean,std,count,mean,std,count,...,std,count,mean,std,count,mean,std,count,mean,std
subject_id,hadm_id,icustay_id,hours_in,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
9,150750,220597,0,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,1.0,8.0,
9,150750,220597,1,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
9,150750,220597,2,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
9,150750,220597,3,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
9,150750,220597,4,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,3.0,7.39,0.0,0.0,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,137810,229633,43,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
99995,137810,229633,44,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
99995,137810,229633,45,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,
99995,137810,229633,46,0.0,,,0.0,,,0.0,,,0.0,...,,0.0,,,0.0,,,0.0,,


In [33]:
save_to_pickle(df=vitals_train_imputed, filename=os.path.join(DATA_DIR, 'Vitals_train_raw.pkl'))
save_to_pickle(df=vitals_dev_imputed, filename=os.path.join(DATA_DIR, 'Vitals_dev_raw.pkl'))
save_to_pickle(df=vitals_test_imputed, filename=os.path.join(DATA_DIR, 'Vitals_test_raw.pkl'))

# save_to_pickle(df=Ys_train, filename=os.path.join(DATA_DIR, 'Y_train_gru.pkl'))
# save_to_pickle(df=Ys_dev, filename=os.path.join(DATA_DIR, 'Y_dev_gru.pkl'))
# save_to_pickle(df=Ys_test, filename=os.path.join(DATA_DIR, 'Y_test_gru.pkl'))