### Configuration

In [None]:
import os

# Root folder for the generated splits; each bootstrap gets its own numbered subfolder.
bootstrap_folder = os.path.expanduser("~/dropbox/sts-data/bootstraps")
# How many independent stratified train/valid/test splits to generate.
num_bootstraps = 10

In [None]:
import sys
import math
import numpy as np
import pandas as pd
from ml4cvd.arguments import parse_args
from ml4cvd.explorations import explore
from typing import List, Tuple, Union

def print_dataframe(df):
    """
    Print an entire DataFrame/Series without pandas truncating rows, columns,
    or cell contents. Be careful when printing very large objects.
    """
    # None disables the row/column/cell-width limits (and lets pandas pick the
    # line width). A negative display.max_colwidth was deprecated in pandas 1.0
    # and is not accepted by later releases, so None is used instead of -1.
    with pd.option_context(
        'display.max_rows', None,
        'display.max_columns', None,
        'display.width', None,
        'display.max_colwidth', None,
    ):
        print(df)

### Load data

In [None]:
# Simulate the ml4cvd command line. The argv is built as an explicit list rather
# than by whitespace-splitting a string, so an expanded home path that contains
# spaces is still passed as a single argument.
sys.argv = [
    ".",
    "--tensors", "/data/ecg/mgh",
    "--sample_csv", os.path.expanduser("~/dropbox/sts-data/mgh-all-features-labels.csv"),
    "--input_tensors",
    "ecg_patientid_clean_sts_newest",
    "ecg_age_sts_newest",
    "ecg_sex_sts_newest",
    "ecg_rate_md_sts_newest",
    "sts_death",
    "--explore_stratify_label",
    "sts_death_sts_death",
    "--output_folder", "/tmp",
    "--id", "explore",
]
args = parse_args()
df = explore(args, save_output=False)

##### The number of quantile bins (age quartiles, heart-rate tertiles) is chosen so that every unique combination of stratification labels contains at least 3 patients (at least 1 each for train/valid/test). The bin counts were tuned by manually adjusting them and inspecting the resulting group sizes.

In [None]:
# Keep patients aged >= 21, drop patients listed in the bad-ECG file, and bin
# the continuous covariates so they can serve as stratification labels.
df['ecg_patientid_clean_sts_newest'] = df['ecg_patientid_clean_sts_newest'].astype(int)
df = df[df['ecg_age_sts_newest'] >= 21]

# NOTE: pandas >= 2.0 parses the literal string 'None' as NaN by default. That
# would make `Problem != 'None'` true for every row and exclude every listed
# patient. Only empty cells are treated as missing here, so 'None' stays a string.
bad = pd.read_csv(
    os.path.expanduser('~/dropbox/sts-data/mgh-bad-ecgs.csv'),
    keep_default_na=False,
    na_values=[''],
)
bad = bad[bad['Problem'] != 'None']

# Anti-join: keep only rows whose patient id has no match among the bad ECGs.
# A left join keeps the left key's dtype (an outer join adds NaN rows for
# unmatched bad MRNs). .copy() gives an independent frame, so the column
# assignments below do not write into a view of the merge result.
df = df.merge(bad, how='left', left_on='ecg_patientid_clean_sts_newest', right_on='MRN', indicator=True)
df = df[df['_merge'] == 'left_only'].copy()

df['mrn'] = df['ecg_patientid_clean_sts_newest'].astype(int)
df['death'] = df['sts_death_sts_death'].astype(int)
df['sex-male'] = df['ecg_sex_sts_newest_male'].astype(int)
# qcut bins on sample quantiles, so each bin holds roughly the same number of patients.
df['age-quartile'], age_bins = pd.qcut(df['ecg_age_sts_newest'], 4, retbins=True, labels=[0, 1, 2, 3])
df['heart-rate-tertile'], hr_bins = pd.qcut(df['ecg_rate_md_sts_newest'], 3, retbins=True, labels=[0, 1, 2])
df = df[['mrn', 'death', 'sex-male', 'age-quartile', 'heart-rate-tertile']]

# Uncomment to inspect the size of every stratification group when tuning the bins:
# print_dataframe(df.groupby(['death', 'sex-male', 'age-quartile', 'heart-rate-tertile']).size())

### Stratify across train, valid, test splits

In [None]:
def _split_group(
    group: pd.DataFrame,
    key,
    test_ratio: float,
    relative_valid_ratio: float,
    rng,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split one stratum into (train, valid, test), each guaranteed non-empty."""
    if len(group) < 3:
        raise ValueError(
            f"Stratum {key!r} has {len(group)} row(s); at least 3 are required "
            "to place one in each of train/valid/test"
        )
    # int() floors the target size; `or 1` ensures at least one row per split.
    # Test is sampled before valid, matching the original draw order.
    n_test = int(test_ratio * len(group)) or 1
    test_part = group.sample(n=n_test, replace=False, random_state=rng)
    remaining = group.drop(test_part.index)

    n_valid = int(relative_valid_ratio * len(remaining)) or 1
    valid_part = remaining.sample(n=n_valid, replace=False, random_state=rng)
    train_part = remaining.drop(valid_part.index)

    if train_part.empty:
        raise ValueError(
            f"Stratum {key!r} ({len(group)} rows) leaves no rows for train with "
            f"test_ratio={test_ratio}; use larger strata or smaller ratios"
        )
    return train_part, valid_part, test_part


def train_valid_test_split(
    df: pd.DataFrame,
    stratify_by: Union[str, List[str]],
    test_ratio: float = 0.1,
    valid_ratio: float = 0.2,
    random_state: Union[None, int, np.random.RandomState] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Randomly split ``df`` into train/valid/test sets, stratified by ``stratify_by``.

    Each unique combination of ``stratify_by`` values (a stratum) is split
    independently and contributes at least one row to each of the three sets.

    Args:
        df: Rows to split. Its index is ignored.
        stratify_by: Column name(s) whose value combinations define the strata.
        test_ratio: Approximate fraction of each stratum assigned to test.
        valid_ratio: Approximate fraction of each stratum assigned to valid.
        random_state: Seed or ``np.random.RandomState`` for reproducible splits.
            ``None`` (default) uses NumPy's global random state.

    Returns:
        ``(train_df, valid_df, test_df)``, each shuffled and with a fresh RangeIndex.

    Raises:
        ValueError: if a ratio is negative, the ratios sum to 1 or more, or a
            stratum is too small to be represented in all three sets.
    """
    if test_ratio < 0 or valid_ratio < 0 or test_ratio + valid_ratio >= 1:
        raise ValueError(
            "Need non-negative ratios with test_ratio + valid_ratio < 1, "
            f"got test_ratio={test_ratio}, valid_ratio={valid_ratio}"
        )
    # Valid is drawn after test has been removed, so rescale its ratio to the remainder.
    relative_valid_ratio = valid_ratio / (1 - test_ratio)

    # Turn an int seed into one generator shared by every sample() call below.
    # Passing the raw seed to each call would re-seed, and every stratum of the
    # same size would then draw the same positions.
    if isinstance(random_state, (int, np.integer)):
        rng = np.random.RandomState(random_state)
    else:
        rng = random_state

    # A unique index guarantees that drop(sampled.index) removes exactly the sampled rows.
    data = df.reset_index(drop=True)

    train_parts, valid_parts, test_parts = [], [], []
    # observed=True: iterate only strata that actually occur. Unused categories
    # of a categorical column would otherwise show up as empty groups.
    for key, group in data.groupby(stratify_by, observed=True):
        train_part, valid_part, test_part = _split_group(
            group, key, test_ratio, relative_valid_ratio, rng
        )
        train_parts.append(train_part)
        valid_parts.append(valid_part)
        test_parts.append(test_part)

    train_df = pd.concat(train_parts)
    valid_df = pd.concat(valid_parts)
    test_df = pd.concat(test_parts)

    # Sanity check: every stratum is represented in every split.
    n_strata = data.groupby(stratify_by, observed=True).ngroups
    assert all(
        part.groupby(stratify_by, observed=True).ngroups == n_strata
        for part in (train_df, valid_df, test_df)
    )

    # Sanity check: the splits partition the original rows exactly. Indexes are
    # reset after sorting so fully tied rows cannot cause a spurious mismatch.
    # Rows with NaN in a stratify_by column are dropped by groupby and would fail here.
    cols = list(data.columns)
    recombined = pd.concat([train_df, valid_df, test_df])
    assert data.sort_values(cols).reset_index(drop=True).equals(
        recombined.sort_values(cols).reset_index(drop=True)
    )

    # Shuffle so rows are not grouped by stratum in the output.
    train_df = train_df.sample(frac=1, random_state=rng).reset_index(drop=True)
    valid_df = valid_df.sample(frac=1, random_state=rng).reset_index(drop=True)
    test_df = test_df.sample(frac=1, random_state=rng).reset_index(drop=True)
    return train_df, valid_df, test_df

### Stratify for bootstraps

In [None]:
# Write `num_bootstraps` independent stratified splits, one numbered subfolder each.
# After the loop, `train`, `valid` and `test` still hold the last bootstrap,
# which the prevalence report at the end of the notebook reads.
stratify_columns = ['death', 'sex-male', 'age-quartile', 'heart-rate-tertile']
for bootstrap_index in range(num_bootstraps):
    train, valid, test = train_valid_test_split(
        df=df,
        stratify_by=stratify_columns,
        test_ratio=0.1,
        valid_ratio=0.2,
    )
    output_dir = os.path.join(bootstrap_folder, str(bootstrap_index))
    os.makedirs(output_dir, exist_ok=True)
    for filename, split_frame in (("train.csv", train), ("valid.csv", valid), ("test.csv", test)):
        split_frame.to_csv(os.path.join(output_dir, filename), index=False)

### Report the distribution of each stratification label across splits

In [None]:
def compute_label_prevalence(
    train: pd.DataFrame, valid: pd.DataFrame, test: pd.DataFrame, label: str
) -> pd.Series:
    """
    Return the percentage of rows having each value of ``label`` within each split.

    The result is indexed by (label value, split). Within each split that has
    rows, the percentages sum to 100.
    """
    combined = pd.concat([train, valid, test], keys=['train', 'valid', 'test'])
    combined = combined.reset_index(0).rename({'level_0': 'split'}, axis=1)
    # Categorical so the splits are reported in train/valid/test order.
    combined['split'] = pd.Categorical(combined['split'], ['train', 'valid', 'test'])
    # observed=False (pandas' historical default, made explicit) keeps
    # (value, split) pairs with zero rows so they are reported as 0%.
    counts = combined.groupby([label, 'split'], observed=False).size()
    # Divide by each split's total directly. groupby(...).apply() would add an
    # extra index level under pandas >= 2.0 (group_keys=True).
    split_totals = counts.groupby(level='split', observed=False).transform('sum')
    return 100 * counts / split_totals


def print_label_prevalence(train: pd.DataFrame, valid: pd.DataFrame, test: pd.DataFrame, label: str):
    """Print the per-split percentage breakdown of ``label``, followed by a blank line."""
    print_dataframe(compute_label_prevalence(train, valid, test, label))
    print()

In [None]:
# For each stratification label, show its percentage breakdown within each split
# of the last bootstrap generated by the loop above.
for stratify_label in ('death', 'sex-male', 'age-quartile', 'heart-rate-tertile'):
    print_label_prevalence(train, valid, test, stratify_label)