# Stratify STS ECGs Across Bootstraps

This notebook has two parts:
1. [Compute ECG similarity/distances](#Compute-ECG-similarity)
2. [Stratify ECGs across bootstraps](#Stratify-ECGs)

## Compute ECG similarity

Calculate Fréchet distances as a metric of ECG similarity. Run this section only if `died_distances.csv` and `lived_distances.csv` do not already exist in `~/dropbox/sts_ecg/distances/`.

In [None]:
! pip install similaritymeasures ~/repos/ml

import multiprocessing as mp
import os
import time
from typing import List, Optional, Set, Tuple

import h5py
import numpy as np
import pandas as pd

from distance import compute_vcg, frechet_distance
from ml4cvd.tensor_maps_ecg import build_cardiac_surgery_tensor_maps

# Type aliases used in the helper signatures below.
MRN, Distance = int, float  # patient medical record number; Fréchet distance value

### helper functions

In [None]:
# Tensor map that get_ecg() uses to read and postprocess each patient's ECG.
_ecg_tmap_name = '12_lead_ecg_2500_std_newest_sts'
tm = build_cardiac_surgery_tensor_maps([_ecg_tmap_name])[_ecg_tmap_name]

def get_ecg(mrn: MRN, data_dir: str = '/data/ecg/mgh/') -> np.ndarray:
    """Load and postprocess the ECG stored for one patient.

    Opens ``<data_dir>/<mrn>.hd5`` and runs it through the module-level tensor
    map ``tm``: ``tensor_from_file`` followed by ``postprocess_tensor``.

    Args:
        mrn: Patient MRN; used as the HD5 file stem.
        data_dir: Directory containing one ``<mrn>.hd5`` file per patient.
            Defaults to the MGH ECG directory this notebook was written for.

    Returns:
        The tensor returned by ``tm.postprocess_tensor`` (presumably a numpy
        array — TODO confirm against the ml4cvd TensorMap API).
    """
    path = os.path.join(data_dir, f'{mrn}.hd5')

    with h5py.File(path, 'r') as hd5:
        # NOTE(review): the positional False passed to postprocess_tensor is an
        # opaque ml4cvd TensorMap flag — confirm its meaning.
        return tm.postprocess_tensor(tm.tensor_from_file(tm, hd5), False, hd5)

def get_distance(mrn: MRN) -> Distance:
    """Fréchet distance between one patient's VCG and the reference VCG.

    The reference is the module-level ``base_vcg``, which ``get_distances``
    assigns before creating its worker pool.

    NOTE(review): worker processes only see ``base_vcg`` if they inherit the
    parent's globals (e.g. the 'fork' start method) — confirm per platform.
    """
    reference = base_vcg
    candidate = compute_vcg(get_ecg(mrn), spherical_coordinates=False)
    return frechet_distance(reference, candidate)

def get_distances(mrns: List[MRN], seed: Optional[int] = None) -> List[Tuple[MRN, Distance]]:
    """Compute each ECG's Fréchet distance to one randomly chosen reference ECG.

    One MRN is drawn at random as the reference. Its VCG is stored in the
    module-level ``base_vcg``, which ``get_distance`` reads. The distances for
    all other MRNs are then computed in parallel with a process pool.

    Args:
        mrns: MRNs (list or 1-D numpy array) of the ECGs to compare.
        seed: Optional seed for choosing the reference ECG. ``None`` (the
            default) draws from numpy's global RNG, as before. Pass an int to
            make the reference choice, and hence the distances, reproducible.

    Returns:
        ``(mrn, distance)`` pairs for every non-reference MRN in input order,
        followed by ``(reference_mrn, 0)``. Empty input returns ``[]``.
    """
    global base_vcg

    # np.random.randint(0, 0) raises ValueError, so handle empty input explicitly.
    if len(mrns) == 0:
        return []

    rng = np.random if seed is None else np.random.RandomState(seed)
    base_idx = rng.randint(0, len(mrns))
    base_mrn = mrns[base_idx]
    base_vcg = compute_vcg(get_ecg(base_mrn), spherical_coordinates=False)

    # Exclude the reference itself; it is appended below with distance 0.
    others = np.delete(mrns, base_idx)

    # NOTE(review): workers read base_vcg as a global, which only works if they
    # inherit the parent's memory (e.g. the 'fork' start method) — confirm.
    with mp.Pool(processes=mp.cpu_count()) as pool:
        distances = pool.map(get_distance, others)

    result = list(zip(others, distances))
    result.append((base_mrn, 0))
    return result

### load initial data

In [None]:
# Parse surgery dates before sorting. Sorting the raw date strings is only
# chronological for ISO (YYYY-MM-DD) formatted dates, and keep='last' relies on
# the sort putting each patient's most recent surgery last.
df = (
    pd.read_csv(os.path.expanduser('~/dropbox/sts_ecg/xref_adults/list_1_in_all_windows.csv'))
    .assign(surgdt=lambda frame: pd.to_datetime(frame['surgdt']))
    .sort_values(['medrecn', 'surgdt'])
    .drop_duplicates(subset='medrecn', keep='last')  # remove duplicate surgeries
)

# MRNs split by the mtopd outcome label (1 = died, 0 = lived); rows with any
# other mtopd value fall into neither group.
died = df.loc[df['mtopd'] == 1, 'medrecn'].to_numpy()
lived = df.loc[df['mtopd'] == 0, 'medrecn'].to_numpy()

### compute frechet distances

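The cell below calculates Fréchet distances between ECGs. Within each outcome group (died, lived), one ECG is chosen at random as the reference. Every other ECG's VCG (computed by `compute_vcg`) is compared to the reference's VCG, and the reference itself is recorded with distance 0. Because the reference is chosen with an unseeded random draw, re-running this computation produces different distances.
This is extremely slow and takes approximately 4 hours on a 20-core CPU.
When complete, two `csv` files are created in `~/dropbox/sts_ecg/distances/`:

    died_distances.csv
    lived_distances.csv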


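The `%%capture output` cell magic stores the cell's output in the variable `output`, which is replayed by `output.show()` in the following cell. This way the output should be preserved even if the browser tab with the notebook is closed.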

In [None]:
%%capture output

t0 = time.time()

# Computing these distances takes hours, and each run picks a new random
# reference ECG (unseeded), so reuse existing results instead of silently
# overwriting them. Delete a CSV file to force it to be recomputed.
distances_dir = os.path.expanduser('~/dropbox/sts_ecg/distances')
os.makedirs(distances_dir, exist_ok=True)

num_computed = 0
for outcome, outcome_mrns in (('died', died), ('lived', lived)):
    csv_path = os.path.join(distances_dir, f'{outcome}_distances.csv')
    if os.path.exists(csv_path):
        print(f'Skipping {outcome}: {csv_path} already exists')
        continue
    outcome_distances = get_distances(outcome_mrns)
    pd.DataFrame(outcome_distances, columns=['mrn', 'distance']).to_csv(csv_path, index=False)
    num_computed += len(outcome_mrns)

t1 = time.time()

print(f'Computed frechet distances for {num_computed} ECGs in {t1-t0:.0f} seconds')

In [None]:
# Replay the output captured by `%%capture output` in the previous cell.
output.show()

## Stratify ECGs

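Stratify ECGs into train/valid/test splits across bootstraps such that death label prevalence is preserved and similar ECGs appear in each split. Within each outcome group, ECGs are sorted by their distance to the group's reference ECG and dealt round-robin into `num_bootstraps` folds. For bootstrap `i`, fold `i` is the test set, the next two folds (cyclically) form the validation set, and the remaining folds form the training set. Despite the name, the "bootstraps" are rotating folds; no resampling is involved.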

In [None]:
import os
import numpy as np
import modin.pandas as pd

In [None]:
# Number of train/valid/test rotations. It is also passed to stratify() as the
# number of folds the ECGs are dealt into; bootstrap i uses fold i as its test set.
num_bootstraps = 10

In [None]:
# Get map between MRNs and new de-identified HD5 names.
# Expand "~" explicitly, like every other path in this notebook, rather than
# relying on the CSV reader to do it.
path_to_deid_map = os.path.expanduser("~/dropbox/ecg/mgh-ecg-deid-map.csv")
df_deid_map = pd.read_csv(path_to_deid_map)
# NOTE: despite the df_ prefix, df_deid_map is a dict {mrn: new_id} from here on.
# If an MRN appears in more than one row, the last row's new_id wins.
df_deid_map = df_deid_map.set_index('mrn')['new_id'].to_dict()
print(len(df_deid_map))

In [None]:
# Function to convert a list of MRNs to their de-identified versions
def deid_mrns(mrns: list, map: dict) -> list:
    """Translate MRNs into their de-identified IDs.

    Args:
        mrns: MRNs to translate, in order.
        map: Lookup from MRN to de-identified ID. The name shadows the builtin
            ``map`` but is kept because callers pass it by keyword.

    Returns:
        A new list containing ``map[mrn]`` for each input MRN, in input order.

    Raises:
        KeyError: if any MRN is missing from ``map``.
    """
    lookup = map
    deidentified = []
    for original_mrn in mrns:
        deidentified.append(lookup[original_mrn])
    return deidentified

In [None]:
# Remove bad ECGs
lived = pd.read_csv(os.path.expanduser('~/dropbox/sts_ecg/distances/lived_distances.csv'))
died = pd.read_csv(os.path.expanduser('~/dropbox/sts_ecg/distances/died_distances.csv'))

bad = pd.read_csv(os.path.expanduser('~/dropbox/sts_data/mgh-bad-ecgs.csv'))
# Rows whose Problem is "None" are fine ECGs. Recent pandas versions (2.0+) parse
# the literal string "None" as NaN by default, so NaN must be excluded as well.
# Otherwise every ECG listed in this file would be treated as bad.
# NOTE(review): genuinely blank Problem cells are also treated as "no problem".
bad = bad[bad['Problem'].notna() & (bad['Problem'] != 'None')]


def _drop_bad_ecgs(distances, bad_ecgs):
    """Return the MRNs in ``distances`` that are absent from ``bad_ecgs['MRN']``, sorted by distance."""
    # A left merge suffices: bad-only rows (right_only) would be discarded anyway.
    merged = distances.merge(bad_ecgs, how='left', left_on='mrn', right_on='MRN', indicator=True)
    kept = merged[merged['_merge'] == 'left_only']
    return kept.sort_values('distance')['mrn'].astype(int)


lived = _drop_bad_ecgs(lived, bad)
died = _drop_bad_ecgs(died, bad)

In [None]:
# MRNs must be unique: no patient may appear in both `lived` and `died`,
# nor twice within the same group.
assert len({*lived, *died}) == len(lived) + len(died)

def stratify(batch_size, valid_idxs, test_idxs, mrns):
    """Deal MRNs round-robin into ``batch_size`` folds and group the folds into splits.

    The element at position ``p`` of ``mrns`` belongs to fold ``p % batch_size``.
    Folds in ``valid_idxs`` go to the validation split. Of the remaining folds,
    those in ``test_idxs`` go to the test split and all others go to training.
    A fold listed in both sets therefore ends up in validation.

    Returns:
        ``(train, valid, test)`` lists, each preserving the input order.
    """
    splits = {'train': [], 'valid': [], 'test': []}
    for position, mrn in enumerate(mrns):
        fold = position % batch_size
        if fold in valid_idxs:
            destination = 'valid'
        elif fold in test_idxs:
            destination = 'test'
        else:
            destination = 'train'
        splits[destination].append(mrn)
    return splits['train'], splits['valid'], splits['test']

# Each bootstrap needs 1 test fold + 2 validation folds + at least 1 training fold.
# With fewer than 3 folds, the validation and test folds would collide.
assert num_bootstraps >= 3, 'num_bootstraps must be at least 3 (1 test + 2 validation folds)'

bootstraps = []
bootstraps_deid = []
for i in range(num_bootstraps):
    # Fold i is the test set and the next two folds (cyclically) are validation.
    # The modulus must be num_bootstraps, not a literal 10, so these fold indices
    # always match the number of folds passed to stratify() below.
    test_idxs = {i % num_bootstraps}
    valid_idxs = {(i + 1) % num_bootstraps, (i + 2) % num_bootstraps}
    died_splits = stratify(num_bootstraps, valid_idxs, test_idxs, died)
    lived_splits = stratify(num_bootstraps, valid_idxs, test_idxs, lived)

    # For each of (train, valid, test): died MRNs followed by lived MRNs.
    bootstraps.append(tuple(
        pd.DataFrame(np.append(died_split, lived_split), columns=['mrn'])
        for died_split, lived_split in zip(died_splits, lived_splits)
    ))
    # Same splits with each MRN replaced by its de-identified ID.
    bootstraps_deid.append(tuple(
        pd.DataFrame(
            np.append(deid_mrns(mrns=died_split, map=df_deid_map),
                      deid_mrns(mrns=lived_split, map=df_deid_map)),
            columns=['mrn'],
        )
        for died_split, lived_split in zip(died_splits, lived_splits)
    ))

In [None]:
# Save bootstraps to CSV: one directory per bootstrap holding its train/valid/test MRN lists.
base_path = os.path.expanduser('~/dropbox/sts_data/bootstraps')
for index, splits in enumerate(bootstraps):
    fold_dir = os.path.join(base_path, str(index))
    os.makedirs(fold_dir, exist_ok=True)
    for split_df, file_name in zip(splits, ('train.csv', 'valid.csv', 'test.csv')):
        split_df.to_csv(os.path.join(fold_dir, file_name), index=False)
    print(f"Saved bootstrap {index} to CSV files")

In [None]:
# Save bootstraps to CSV: de-id. Same layout as the identified bootstraps,
# but each file lists de-identified IDs instead of MRNs.
base_path = os.path.expanduser('~/dropbox/sts_data/bootstraps_deid')
for index, splits in enumerate(bootstraps_deid):
    fold_dir = os.path.join(base_path, str(index))
    os.makedirs(fold_dir, exist_ok=True)
    for split_df, file_name in zip(splits, ('train.csv', 'valid.csv', 'test.csv')):
        split_df.to_csv(os.path.join(fold_dir, file_name), index=False)
    print(f"Saved bootstrap (de-id) {index} to CSV files")