# ECG Latent Space Exploration
This notebook shows how to create a latent space from a pretrained multimodal model.

In [None]:
import os
import sys
from typing import Callable, List, Dict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split


from torch.utils.data import DataLoader

import tensorflow as tf
from tensorflow.keras.losses import CategoricalCrossentropy

from tensorflow.keras.optimizers import Adam

from ml4h_ccds.data_descriptions.ecg import ECGDataDescription
from ml4h_ccds.data_descriptions.wide_file import WideFileDataDescription
from ml4h_ccds.data_descriptions.util import download_s3_if_not_exists

from ml4h.models.model_factory import block_make_multimodal_multitask_model
from ml4h.TensorMap import TensorMap, Interpretation

# ml4h Imports
from ml4h.arguments import parse_args
from ml4h.metrics import coefficient_of_determination
from ml4h.explorations import latent_space_dataframe
from ml4h.models.model_factory import block_make_multimodal_multitask_model


from ml4ht.data.util.date_selector import DateRangeOptionPicker, first_dt, DATE_OPTION_KEY, DateRangeOptionPicker
from ml4ht.data.util.data_frame_data_description import DataFrameDataDescription 
from ml4ht.data.data_description import DataDescription
from ml4ht.data.sample_getter import DataDescriptionSampleGetter
from ml4ht.data.explore import explore_data_descriptions, explore_sample_getter
from ml4ht.data.data_loader import SampleGetterDataset, numpy_collate_fn, SampleGetterIterableDataset

import math
import time
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
from scipy import stats
from sklearn.linear_model import LogisticRegression, LinearRegression, ElasticNet, Ridge
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.metrics import brier_score_loss, precision_score, recall_score, f1_score, roc_auc_score

from ml4h.explorations import latent_space_dataframe

# IPython imports
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import colors

In [None]:
SESSION_DIR = os.path.expanduser("~")  # downloaded data will be stored here

# ECG data description

In [None]:
def normalize_ecg(ecg, _):
    """Scale an ECG by a factor of 1/1000.

    Intended to express amplitudes in millivolts (presumably the stored
    values are microvolts -- TODO confirm against the source files).

    The second positional argument is accepted and ignored so the function
    has the same two-argument shape as the other transforms in this notebook.
    """
    scaled = ecg / 1000
    return scaled

def standardize_by_sample_ecg(ecg, _):
    """Z-score a single ECG using its own mean and standard deviation.

    The mean and standard deviation are computed over every element of
    ``ecg`` (all axes together), so the result is unitless -- it is NOT a
    conversion to millivolts.

    The second positional argument is accepted and ignored so the function
    has the same two-argument shape as the other transforms in this notebook.
    """
    centered = ecg - np.mean(ecg)
    # The small epsilon keeps a constant (flat-line) signal from dividing by zero.
    return centered / (np.std(ecg) + 1e-6)

# --- MGH ECG data descriptions -------------------------------------------
# Six descriptions reading from the 'ecg_mgh_hd5s' path of the same bucket.
# They come in input/output pairs that differ only in `name`:
#   * all leads, z-scored per sample        (ecg_dd_i / ecg_dd_o)
#   * lead I only, z-scored per sample      (ecg_dd_i_lead_I / ecg_dd_o_lead_I)
#   * lead I only, divided by 1000          (ecg_dd_i_lead_I_mv / ecg_dd_o_lead_I_mv)
ecg_dd_i = ECGDataDescription(
    SESSION_DIR, 
    name='input_ecg_5000_std_continuous', 
    ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
    transforms=[standardize_by_sample_ecg],  # these will be applied in order
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_mgh_hd5s'],   # 'ecg_mgh_hd5s',  list of hd5s
)

# Same settings as ecg_dd_i; only the name differs.
ecg_dd_o = ECGDataDescription(
    SESSION_DIR, 
    name='output_ecg_5000_std_continuous', 
    ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
    transforms=[standardize_by_sample_ecg],  # these will be applied in order
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_mgh_hd5s'],   # 'ecg_mgh_hd5s',  list of hd5s
)
# Lead-I-only variants: `leads` maps the lead name to its output channel index.
ecg_dd_i_lead_I = ECGDataDescription(
    SESSION_DIR, 
    name='input_ecg_lead_I_5000_std_continuous', 
    ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
    transforms=[standardize_by_sample_ecg],  # these will be applied in order
    leads={'I':0},
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_mgh_hd5s'],   # 'ecg_mgh_hd5s',  list of hd5s
)

ecg_dd_o_lead_I = ECGDataDescription(
    SESSION_DIR, 
    name='output_ecg_lead_I_5000_std_continuous', 
    ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
    transforms=[standardize_by_sample_ecg],  # these will be applied in order
    leads={'I':0},
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_mgh_hd5s'],   # 'ecg_mgh_hd5s',  list of hd5s
)
# Lead-I-only variants that use normalize_ecg (divide by 1000) instead of
# per-sample z-scoring.
ecg_dd_i_lead_I_mv = ECGDataDescription(
    SESSION_DIR, 
    name='input_ecg_lead_I_5000_mv_continuous', 
    ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
    transforms=[normalize_ecg],  # these will be applied in order
    leads={'I':0},
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_mgh_hd5s'],   # 'ecg_mgh_hd5s',  list of hd5s
)

ecg_dd_o_lead_I_mv = ECGDataDescription(
    SESSION_DIR, 
    name='output_ecg_lead_I_5000_mv_continuous', 
    ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
    transforms=[normalize_ecg],  # these will be applied in order
    leads={'I':0},
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_mgh_hd5s'],   # 'ecg_mgh_hd5s',  list of hd5s
)

In [None]:
# --- BWH ECG data descriptions -------------------------------------------
# Same settings as the MGH input descriptions, but reading from the
# 'ecg_bwh_hd5s' path of the bucket.
# NOTE(review): each `name` below is identical to the name of the
# corresponding MGH description, so the two cannot be told apart by name
# alone -- confirm this is intended.
ecg_dd_i_bwh = ECGDataDescription(
    SESSION_DIR, 
    name='input_ecg_5000_std_continuous', 
    ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
    transforms=[standardize_by_sample_ecg],  # these will be applied in order
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_bwh_hd5s'],   # 'ecg_bwh_hd5s',  list of hd5s
)
# Lead I only, z-scored per sample.
ecg_dd_i_bwh_lead_I = ECGDataDescription(
    SESSION_DIR, 
    name='input_ecg_lead_I_5000_std_continuous', 
    ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
    transforms=[standardize_by_sample_ecg],  # these will be applied in order
    leads={'I':0},
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_bwh_hd5s'],   # 'ecg_bwh_hd5s',  list of hd5s
)
# Lead I only, divided by 1000 (normalize_ecg).
ecg_dd_i_bwh_lead_I_mv = ECGDataDescription(
    SESSION_DIR, 
    name='input_ecg_lead_I_5000_mv_continuous', 
    ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
    transforms=[normalize_ecg],  # these will be applied in order
    leads={'I':0},
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_bwh_hd5s'],   # 'ecg_bwh_hd5s',  list of hd5s
)

In [None]:
# Copy pasted from ml4h branch nd_ml4ht_integration
def _not_implemented_tensor_from_file(_, __, ___=None):
    """Placeholder loader that always fails.

    Installed as ``tensor_from_file`` on TensorMaps that exist only to
    describe a tensor, so any attempt to load data through them is caught
    immediately.  All arguments are ignored.

    Raises:
        NotImplementedError: always.
    """
    message = 'This TensorMap cannot load data.'
    raise NotImplementedError(message)
    

def tensor_map_from_data_description(
        data_description: DataDescription,
        interpretation: Interpretation,
        shape,
        name=None,
        **tensor_map_kwargs,
) -> TensorMap:
    """Build a TensorMap that stands in for a DataDescription.

    This lets a DataDescription be handed to the model factory, which
    expects TensorMaps.  The resulting TensorMap cannot load data: its
    ``tensor_from_file`` always raises ``NotImplementedError``.

    Args:
        data_description: Only its ``name`` is read, and only when ``name``
            is not given (or is falsy).
        interpretation: Passed through to ``TensorMap``.
        shape: Passed through to ``TensorMap``.
        name: Optional explicit name for the TensorMap.
        **tensor_map_kwargs: Extra keyword arguments forwarded to ``TensorMap``.

    Returns:
        The newly constructed TensorMap.
    """
    # An explicit, truthy name wins; otherwise reuse the description's name.
    tmap_name = name or data_description.name
    return TensorMap(
        name=tmap_name,
        interpretation=interpretation,
        shape=shape,
        tensor_from_file=_not_implemented_tensor_from_file,
        **tensor_map_kwargs,
    )


ecg_tmap = tensor_map_from_data_description(
    ecg_dd_i,
    Interpretation.CONTINUOUS,
    (5000, 12), name='ecg_5000_std'
)

In [None]:
import sys
from ml4h.arguments import parse_args
from ml4h.metrics import coefficient_of_determination
from ml4h.explorations import latent_space_dataframe
from ml4h.models.model_factory import block_make_multimodal_multitask_model


# Load the model that maps a 10 s, 12-lead ECG (5000 x 12, per `ecg_tmap`)
# to its translated output; it is bound to `ecg_10s_2_median` below.
# NOTE(review): `model_name` is assigned twice (only the second value
# survives) and is not used anywhere in this cell -- the model file path
# below is hard-coded.
model_name = 'ecg_2500_autoencoder_mgh_c3po_128d_v2021_12_17'
model_name = 'mgh_ecg_2500_std_autoencoder_v2022_03_29'
# parse_args() is called without arguments, so the command line is supplied
# by overwriting sys.argv.
sys.argv = ['train',
            '--activation', 'mish',
            # NOTE(review): '--pool_type' appears twice ('max' here, 'average'
            # on the next line); presumably the later value wins -- confirm.
            '--block_size', '8', '--conv_width', '61', '--conv_layers', '64', '64', '--pool_type', 'max', 
            '--dense_blocks', '64', '64', '--dense_layers', '128', '--pool_type', 'average',
            '--learning_rate', '0.00002',
            '--encoder_blocks', 'conv_encode', '--merge_blocks', '--decoder_blocks', 'conv_decode',
           '--tensormap_prefix', 'ml4h.tensormap.mgb.ecg',
            # NOTE(review): f-string without placeholders; the `f` prefix has no effect.
            '--model_file', f'../../ecg_rest_to_ecg_median_translator_256d.h5',
           
           ]
args = parse_args()
# The same TensorMap is used as both the model input and the model output.
args.tensor_maps_in = [ecg_tmap]
args.tensor_maps_out = [ecg_tmap]

# Only the first of the four returned objects (the full model) is kept.
ecg_10s_2_median, _, _, _ = block_make_multimodal_multitask_model(**args.__dict__)

In [None]:
import sys
from ml4h.arguments import parse_args
from ml4h.metrics import coefficient_of_determination
from ml4h.explorations import latent_space_dataframe
from ml4h.models.model_factory import block_make_multimodal_multitask_model

# Load the 12-lead median-beat autoencoder and keep all four objects the
# factory returns (full model, encoders, decoders, merger).
model_name = 'mgh_ecg_rest_median_raw_10_autoencoder_256d_v2022_04_13'


# parse_args() is called without arguments, so the command line is supplied
# by overwriting sys.argv.
sys.argv = ['train',
            '--activation', 'mish',
            # NOTE(review): '--pool_type' appears twice ('max' here, 'average'
            # on the next line); presumably the later value wins -- confirm.
            '--block_size', '8', '--conv_width', '61', '--conv_layers', '64', '64', '--pool_type', 'max', 
            '--dense_blocks', '64', '64', '--dense_layers', '128', '--pool_type', 'average',
            '--learning_rate', '0.00002',
            '--encoder_blocks', 'conv_encode', '--merge_blocks', '--decoder_blocks', 'conv_decode',
           '--tensormap_prefix', 'ml4h.tensormap.mgb.ecg',
            '--model_file', f'../../trained_models/{model_name}/{model_name}.h5',
           
           ]
args = parse_args()


# TensorMap describing the (600, 12) median-beat tensor.  `ecg_dd_i` is passed
# only to satisfy the helper's signature: because `name` is given explicitly,
# nothing is read from the data description.
ecg_tmap_median = tensor_map_from_data_description(
    ecg_dd_i,
    Interpretation.CONTINUOUS,
    (600, 12), name='ecg_rest_median_raw_10'
)
# The same TensorMap is used as both the model input and the model output.
args.tensor_maps_in = [ecg_tmap_median]
args.tensor_maps_out = [ecg_tmap_median]

ecg_autoencoder, encoders, decoders, merger = block_make_multimodal_multitask_model(**args.__dict__)

In [None]:
import sys
from ml4h.arguments import parse_args
from ml4h.metrics import coefficient_of_determination
from ml4h.explorations import latent_space_dataframe
from ml4h.models.model_factory import block_make_multimodal_multitask_model

# Load the lead-I-only median-beat autoencoder; only its encoders are kept.
model_name = 'mgh_ecg_rest_median_raw_10_lead_I_autoencoder_256d_v2022_04_09'


# parse_args() is called without arguments, so the command line is supplied
# by overwriting sys.argv.
sys.argv = ['train',
            '--activation', 'mish',
            # NOTE(review): '--pool_type' appears twice ('max' here, 'average'
            # on the next line); presumably the later value wins -- confirm.
            '--block_size', '8', '--conv_width', '61', '--conv_layers', '64', '64', '--pool_type', 'max', 
            '--dense_blocks', '64', '64', '--dense_layers', '128', '--pool_type', 'average',
            '--learning_rate', '0.00002',
            '--encoder_blocks', 'conv_encode', '--merge_blocks', '--decoder_blocks', 'conv_decode',
           '--tensormap_prefix', 'ml4h.tensormap.mgb.ecg',
            '--model_file', f'../../trained_models/{model_name}/{model_name}.h5',
           
           ]
args = parse_args()


# TensorMap describing the (600, 1) single-lead median-beat tensor.  `ecg_dd_i`
# is passed only to satisfy the helper's signature: because `name` is given
# explicitly, nothing is read from the data description.
ecg_tmap_median_lead_I = tensor_map_from_data_description(
    ecg_dd_i,
    Interpretation.CONTINUOUS,
    (600, 1), name='ecg_rest_median_raw_10_lead_I'
)
# The same TensorMap is used as both the model input and the model output.
args.tensor_maps_in = [ecg_tmap_median_lead_I]
args.tensor_maps_out = [ecg_tmap_median_lead_I]

# Second element of the returned tuple = encoders; the rest are discarded.
_, encoders_lead_I, _, _ = block_make_multimodal_multitask_model(**args.__dict__)

In [None]:
%matplotlib inline
# Pick one patient and plot the ECG for the last loading option returned.
# Only the final `mrn` assignment takes effect; the earlier ones are
# alternative example patients kept for quick switching.
mrn = 1519973
mrn = 5212097
mrn=4719681
#mrn=4282470
options = ecg_dd_i.get_loading_options(mrn)

# The 5000 samples are spread evenly over a 0-10 s axis.
# NOTE(review): ecg_dd_i is built with the per-sample standardization
# transform; if get_raw_data applies transforms the values are unitless and
# the "(mV)" label would be inaccurate -- confirm.
plt.plot(np.linspace(0, 10, 5000), ecg_dd_i.get_raw_data(mrn, options[-1]))
plt.xlabel("time (s)")
plt.ylabel("amplitude (mV)")
plt.show()

# Keep the same ECG in `example` for the cells below (this repeats the
# get_loading_options call made above).
options = ecg_dd_i.get_loading_options(mrn)
example = ecg_dd_i.get_raw_data(mrn, options[-1])
example.shape

In [None]:
mecg = ecg_10s_2_median(np.array([example]))

In [None]:
plt.plot(np.linspace(0, 10, 600), mecg[0])
plt.xlabel("time (s)")
plt.ylabel("amplitude (mV)")
plt.show()

In [None]:
# Plot every channel of the model output `mecg` on a 3 x 4 grid.
# `leads` is the left-to-right, top-to-bottom order of the subplots.
leads = ['I', 'aVR', 'V1', 'V4', 
             'II', 'aVL', 'V2', 'V5', 
             'III', 'aVF', 'V3', 'V6', ]

# Lead name -> channel index used to slice the last axis of `mecg`.
# NOTE(review): this ordering (precordial leads at 3-8, augmented leads at
# 9-11) differs from the `channel_map` defined in the following cell for the
# raw ECG -- confirm it matches the channel order of the model output.
channel_map = {
    'I': 0, 'II': 1, 'III': 2, 'V1': 3, 'V2': 4, 'V3': 5,
    'V4': 6, 'V5': 7, 'V6': 8, 'aVF': 9, 'aVL': 10, 'aVR': 11,
}
fig, axes = plt.subplots(3, 4, figsize=(12, 4), dpi=300, sharey=False, sharex=True)
for i, (ax, lead) in enumerate(zip(axes.ravel(), leads)):
    # The x values are sample indices 0..599 even though the label says seconds.
    ax.plot(range(600), mecg[0, :, channel_map[lead]])
    ax.set_title(f"Lead: {lead}")
    ax.set_xlabel("time (s)")
    ax.set_ylabel("amplitude (mV)")
plt.tight_layout()

In [None]:
# Plot samples 2000-3499 of every lead of the raw `example` ECG on a 3 x 4 grid.
# `leads` is the left-to-right, top-to-bottom order of the subplots.
leads = ['I', 'aVR', 'V1', 'V4', 
             'II', 'aVL', 'V2', 'V5', 
             'III', 'aVF', 'V3', 'V6', ]

# Lead name -> channel index used to slice the last axis of `example`.
# plot_biosspy and plot_biosspy_median read the module-level `leads` and
# `channel_map`, so whichever definition was executed last is the one they use.
channel_map = {
    'I': 0, 'II': 1, 'III': 2, 'V1': 6, 'V2': 7, 'V3': 8,
    'V4': 9, 'V5': 10, 'V6': 11, 'aVF': 5, 'aVL': 4, 'aVR': 3,
}
fig, axes = plt.subplots(3, 4, figsize=(12, 4), dpi=300, sharey=False, sharex=True)
for i, (ax, lead) in enumerate(zip(axes.ravel(), leads)):
    # The x values are sample indices even though the label says seconds.
    ax.plot(range(2000,3500), example[2000:3500, channel_map[lead]])
    ax.set_title(f"Lead: {lead}")
    ax.set_xlabel("time (s)")
    ax.set_ylabel("amplitude (mV)")
plt.tight_layout()

In [None]:
import os
import sys
import math
import argparse
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
from scipy import ndimage
from biosppy.signals.ecg import ecg
# Keras imports
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

# IPython imports
from IPython.display import Image
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
def stretch_ecg(x, n=0, sampling_rate=500):
    """Resample a single-lead ECG so its mean heart rate becomes ``n`` bpm.

    The mean heart rate is measured with BioSPPy, then the signal is linearly
    re-interpolated at sample positions scaled by ``n / hr``.  The output has
    the same number of samples as the input; positions past the end of the
    input take the last input value (``np.interp`` clamps).

    Args:
        x: 1-D array holding one ECG lead.
        n: Target heart rate in beats per minute.  ``0`` means "do not
            stretch" (the signal is only passed through ``np.interp``).
        sampling_rate: Sampling frequency in Hz used for BOTH BioSPPy calls.
            Previously the second call omitted it and silently fell back to
            BioSPPy's own default rather than the 500 Hz used for the first.

    Returns:
        ``(stretched, rpeaks)``: the resampled signal and the R-peak sample
        indices BioSPPy finds in it.

    NOTE(review): if BioSPPy reports no heart-rate values, ``hr`` is NaN and
    the stretched signal will be NaN as well -- callers may want to guard.
    """
    out = ecg(x.copy(), sampling_rate=sampling_rate, show=False)
    hr = out[-1].mean()  # last field of the BioSPPy result is the heart-rate series
    t = np.arange(len(x))
    # Sampling the input at t * n / hr compresses (n > hr) or stretches
    # (n < hr) the beats in time.
    tp = t if n == 0 else t * n / hr
    stretched = np.interp(tp, t, x)
    print(f'{x.shape} and {stretched.shape} and hr: {hr} and n{n}')
    # The stretched signal is on the same sample grid, so reuse the same rate.
    out2 = ecg(stretched, sampling_rate=sampling_rate, show=False)
    return stretched, out2[2]

def plot_biosspy(mrn, bpm = 60):
    """Plot all 12 leads of a patient's ECG with detected R-peaks marked.

    Loads the ECG for the last loading option of ``mrn`` from ``ecg_dd_i``,
    stretches every lead to ``bpm`` with ``stretch_ecg`` and draws it on a
    3 x 4 grid, with a dashed orange vertical line at each R-peak.

    Args:
        mrn: Sample id passed to ``ecg_dd_i``.
        bpm: Target heart rate forwarded to ``stretch_ecg``.

    Uses the module-level ``leads`` and ``channel_map``.
    """
    options = ecg_dd_i.get_loading_options(mrn)
    example = ecg_dd_i.get_raw_data(mrn, options[-1])
    fig, axes = plt.subplots(3, 4, figsize=(48, 16), dpi=300, sharey=False, sharex=True)
    for ax, lead in zip(axes.ravel(), leads):
        stretched, peaks = stretch_ecg(example[:, channel_map[lead]], bpm)
        # Use the actual signal length instead of a hard-coded 5000 so the
        # plot still works if the ECG length changes.
        # NOTE(review): the x axis is a sample index although it is labelled
        # in seconds.
        ax.plot(range(len(stretched)), stretched)
        for p in peaks:
            ax.axvline(p, linestyle='dashed', c='orange')
        ax.set_title(f"Lead: {lead}")
        ax.set_xlabel("time (s)")
        ax.set_ylabel("amplitude (mV)")
    plt.tight_layout()
    
    
    
def _beat_windows(signal, peaks, median_size):
    """Return the full-length windows that start midway between consecutive R-peaks."""
    windows = []
    # Same pairing as before: peak j with peak j+1, for every j except the last two.
    for current_peak, next_peak in zip(peaks[:-2], peaks[1:]):
        start = (current_peak + next_peak) // 2
        window = signal[start:start + median_size]
        # A window that runs off the end of the signal is shorter than
        # median_size; stacking it with the others would give a ragged array.
        if len(window) == median_size:
            windows.append(window)
    return windows


def _median_beat(signal, peaks, median_size):
    """Return the sample-wise median beat, or all-NaN when no full window exists."""
    windows = _beat_windows(signal, peaks, median_size)
    if not windows:
        return np.full(median_size, np.nan)
    return np.median(np.array(windows), axis=0)


def plot_biosspy_median(mrn, median_size = 600, bpm = 0):
    """Plot a median beat for each of the 12 leads of a patient's ECG.

    Loads the ECG for the last loading option of ``mrn`` from ``ecg_dd_i``.
    For every lead the signal is passed through ``stretch_ecg``; windows of
    ``median_size`` samples, each starting midway between two consecutive
    R-peaks, are collected and their sample-wise median is plotted on a
    3 x 4 grid.

    Windows truncated by the end of the signal are skipped.  A lead with no
    complete window is filled with NaN, so its subplot stays empty.

    Args:
        mrn: Sample id passed to ``ecg_dd_i``.
        median_size: Number of samples in each beat window.
        bpm: Target heart rate forwarded to ``stretch_ecg`` (0 = no stretch).

    Uses the module-level ``leads`` and ``channel_map``.
    """
    options = ecg_dd_i.get_loading_options(mrn)
    example = ecg_dd_i.get_raw_data(mrn, options[-1])

    medians = np.zeros((median_size, len(channel_map)))
    for lead in leads:
        channel = channel_map[lead]
        stretched, peaks = stretch_ecg(example[:, channel], bpm)
        medians[:, channel] = _median_beat(stretched, peaks, median_size)

    fig, axes = plt.subplots(3, 4, figsize=(12, 4), dpi=300, sharey=False, sharex=True)
    for ax, lead in zip(axes.ravel(), leads):
        # NOTE(review): the x axis is a sample index although it is labelled
        # in seconds.
        ax.plot(range(median_size), medians[:, channel_map[lead]])
        ax.set_title(f"Lead: {lead}")
        ax.set_xlabel("time (s)")
        ax.set_ylabel("amplitude (mV)")    
    plt.tight_layout()
    plt.show()

In [None]:
plot_biosspy_median(1519973, bpm = 60)

In [None]:
plot_biosspy_median(4282470, bpm = 0)

In [None]:
#plot_biosspy_median(4282470, bpm=0)
# plot_biosspy_median(4719681)
# plot_biosspy_median(1519973)
# plot_biosspy_median(4282470)
plot_biosspy_median(1519973, bpm = 0)


In [None]:
mrn = 1519973
#mrn = 5212097
#mrn=4719681
#plot_biosspy(5212097)
# plot_biosspy(4719681)
#plot_biosspy(1519973)

plot_biosspy(1519973)

In [None]:
bpm = 60
stretched, peaks = stretch_ecg(example[:,0], bpm)
print(f'peaks are: {peaks}')

In [None]:
fig, axes = plt.subplots(3, 4, figsize=(12, 4), dpi=300, sharey=False, sharex=True)
for i, (ax, lead) in enumerate(zip(axes.ravel(), leads)):
    stretched, peaks = stretch_ecg(example[:,channel_map[lead]], bpm)
    print(f'at lead {lead} peaks are: {peaks}')
    ax.plot(range(0,5000), stretched)
    for p in peaks:
        ax.axvline(p, linestyle='dashed', c='orange')
    ax.set_title(f"Lead: {lead}")
    ax.set_xlabel("time (s)")
    ax.set_ylabel("amplitude (mV)")
plt.tight_layout()

In [None]:
# download the wide file
wide_path = download_s3_if_not_exists(
    bucket_name='2017P001650',
    bucket_path='csvs/charge_set_plusothers_2021_08_10_mgh_and_bwh.tsv',
    local_dir=SESSION_DIR,
)

In [None]:
# load the wide file
wide_df = pd.read_csv(wide_path, sep="\t")

In [None]:
wide_df.dropna(subset=["MRN"], inplace=True)
wide_df["MRN"] = wide_df["MRN"].astype(int)
wide_df["fpath_x"] = wide_df["fpath_x"].astype(int, errors='ignore')
wide_df["fpath_y"] = wide_df["fpath_y"].astype(int, errors='ignore')

In [None]:
# Index the wide file by MRN and keep only the first row for each MRN.
wide_df.set_index("MRN", inplace=True)

# .copy() makes the de-duplicated frame independent of the original, so the
# column assignments below do not write into a slice of another frame.
wide_df = wide_df[~wide_df.index.duplicated()].copy()

# Ages are stored as day counts; convert them to timedeltas.
wide_df["af_age"] = pd.to_timedelta(wide_df["af_age"], unit="d")

# NOTE(review): "last_encounter" is parsed as a timedelta while
# "start_fu_date" is parsed as a datetime -- confirm that is what the
# column actually holds.
wide_df["last_encounter"] = pd.to_timedelta(wide_df["last_encounter"])
wide_df["start_fu_date"] = pd.to_datetime(wide_df["start_fu_date"])

# This conversion used to be written twice on identical lines; once is enough.
wide_df["start_fu_age"] = pd.to_timedelta(wide_df["start_fu_age"])

In [None]:
def label_hospital(row):
    """Return the hospital label for one wide-file row.

    A row whose ``fpath_x`` entry is missing is labelled ``'BWH'``; any
    other row is labelled ``'MGH'``.
    """
    if pd.isna(row['fpath_x']):
        return 'BWH'
    return 'MGH'
wide_df['hospital'] = wide_df.apply(lambda row: label_hospital(row), axis=1)      

In [None]:
wide_df.head()

In [None]:
class ColumnFromWideFileDD(DataDescription):
    """DataDescription that serves one column of a wide-file DataFrame.

    The value for a sample is read from the row of ``wide_df`` indexed by the
    sample id and returned as a float32 tensor: one-hot when a
    ``channel_map`` is given, otherwise a single float.
    """

    def __init__(
        self,
        wide_df: pd.DataFrame,
        column: str,  # e.g. Dem.Gender.no_filter
        channel_map: Dict[str,int] = None,
        transform: Callable = None,
    ):
        """Store the configuration; no data is read here.

        Args:
            wide_df: Frame indexed by sample id.
            column: Name of the column to serve.
            channel_map: Optional mapping from lower-case category label to
                its index in the one-hot tensor.
            transform: Optional callable applied to the tensor before it is
                returned.
        """
        self.wide_df = wide_df
        self.column = column
        self.channel_map = channel_map
        self.transform = transform

    def get_loading_options(self, sample_id):
        """Return an empty list: this description offers no loading options."""
        return []

    def get_raw_data(self, sample_id, loading_option):
        """Return the column value for ``sample_id`` as a float32 tensor.

        ``loading_option`` is not used.
        """
        # Selecting with a list always yields a frame; take its first row.
        value = self.wide_df.loc[[sample_id]].iloc[0][self.column]
        if self.channel_map:
            tensor = np.zeros((len(self.channel_map),), dtype=np.float32)
            label = value.lower()
            # No matching label leaves the tensor all zeros.
            if label in self.channel_map:
                tensor[self.channel_map[label]] = 1.0
        else:
            tensor = np.zeros((1,), dtype=np.float32)
            tensor[0] = float(value)
        return self.transform(tensor) if self.transform else tensor

    @property
    def name(self):
        """The column name doubles as the description's name."""
        return self.column
    
class AgeAtECGWideFileDD(DataDescription):
    """DataDescription that computes a normalized age at the time of an ECG.

    The age is derived from a reference date and the age at that reference
    date, both taken from the wide-file row of the sample.
    """

    def __init__(
        self,
        wide_df: pd.DataFrame,
        reference_date_column: str,  # e.g. start_fu
        reference_age_column: str,  # e.g. start_fu_age
    ):
        """Store the configuration; no data is read here.

        Args:
            wide_df: Frame indexed by sample id.
            reference_date_column: Column holding the reference date.
            reference_age_column: Column holding the age at the reference date.
        """
        self.wide_df = wide_df
        self.reference_date_column = reference_date_column
        self.reference_age_column = reference_age_column

    def get_loading_options(self, sample_id):
        """Return a single option whose date is the sample's reference date."""
        row = self.wide_df.loc[[sample_id]].iloc[0]
        reference_date = row[self.reference_date_column].to_pydatetime()
        return [{DATE_OPTION_KEY: reference_date}]

    def get_raw_data(self, sample_id, loading_option):
        """Return the normalized age at the ECG as a float32 array.

        Expects the time of the ECG in ``loading_option`` under
        ``DATE_OPTION_KEY``.
        """
        ecg_date = loading_option[DATE_OPTION_KEY]
        row = self.wide_df.loc[sample_id]
        # Age at ECG = age at the reference date minus the time between the
        # ECG and the reference date.
        elapsed_since_ecg = row[self.reference_date_column] - ecg_date
        age_at_ecg = row[self.reference_age_column] - elapsed_since_ecg
        # 31536000 s = 365 days.
        age_in_years = age_at_ecg.total_seconds()/31536000
        # Presumably 63.36 / 7.55 are the cohort mean and standard deviation
        # of age in years -- TODO confirm.
        norm_age = (age_in_years - 63.36) / 7.55
        return np.array(norm_age, dtype=np.float32) 

    @property
    def name(self):
        """Fixed name of this description."""
        return "output_age_from_wide_csv_continuous"

In [None]:
# build the data description, and make sure it works!
sex_from_wide_dd = ColumnFromWideFileDD(
    wide_df=wide_df,
    column="Dem.Gender.no_filter",
    channel_map={'female': 0, 'male': 1}
)

def bmi_check(x):
    """Validate that the first element of ``x`` is between 12 and 50 inclusive.

    Returns ``x`` unchanged when the value is in range.

    Raises:
        ValueError: if ``x[0]`` is above 50 or below 12.
    """
    bmi = x[0]
    too_high = bmi > 50
    too_low = bmi < 12
    if too_high or too_low:
        raise ValueError('bmi out of range')
    return x

# Build one ColumnFromWideFileDD per wide-file column listed in `wide_cols`.
wide_dds = []
wide_cols = [
     'MGH_MRN_0',
     'BWH_MRN_0',
]
for col in wide_cols:
    # NOTE(review): 'start_fu_BMI' is not in `wide_cols` above, so this branch
    # never runs with the current list; it only matters if that column is added.
    if col == 'start_fu_BMI':
        wide_dds.append(ColumnFromWideFileDD(wide_df=wide_df, column=col, transform=bmi_check))
    else:
        wide_dds.append(ColumnFromWideFileDD(wide_df=wide_df, column=col))

In [None]:
# Look-back window before the start of follow-up in which an ECG is eligible.
# Spelled out in seconds (3 x 365.2425 days) because the ambiguous "y" unit
# formerly written as pd.to_timedelta("3y") is rejected by newer pandas.
_ECG_LOOKBACK = pd.Timedelta(seconds=3 * 31556952)

# Site string that identifies a BWH ECG.
_BWH_SITE_NAME = "BRIGHAM & WOMEN'S/FAULKNER HOSP."


def _site_matches_hospital(site, hospital):
    """Return True when an ECG's SITE string belongs to the patient's hospital."""
    if hospital == 'BWH':
        return site == _BWH_SITE_NAME
    if hospital == 'MGH':
        return 'MGH' in site
    return False


def _pick_latest_ecg_option(sample_id, dds, ecg_dd):
    """Pick the last eligible ECG loading option of ``ecg_dd`` for every dd in ``dds``."""
    row = wide_df.loc[sample_id]
    start_fu = row["start_fu_date"]
    hospital = row["hospital"]
    min_ecg_dt = start_fu - _ECG_LOOKBACK
    max_ecg_dt = start_fu

    # Keep only the date and S3 path of options that fall inside the window
    # and were recorded at the patient's hospital.
    eligible = [
        {
            DATE_OPTION_KEY: option[DATE_OPTION_KEY],
            ecg_dd.S3_PATH_OPTION: option[ecg_dd.S3_PATH_OPTION],
        }
        for option in ecg_dd.get_loading_options(sample_id)
        if min_ecg_dt <= option[DATE_OPTION_KEY] <= max_ecg_dt
        and _site_matches_hospital(option['SITE'], hospital)
    ]
    if not eligible:
        raise ValueError("No dates available")

    # Take the last eligible option (presumably the most recent one -- this
    # relies on the order returned by get_loading_options; TODO confirm).
    chosen = eligible[-1]
    return {dd: chosen for dd in dds}


def option_picker(sample_id, dds):
    """Choose one MGH-store ECG for ``sample_id`` and assign it to every dd.

    An ECG is eligible when it was taken within the look-back window ending
    at the patient's ``start_fu_date`` and its SITE matches the patient's
    hospital as recorded in ``wide_df``.

    Args:
        sample_id: Index value into ``wide_df``.
        dds: Data descriptions that must all receive the same loading option.

    Returns:
        Dict mapping every element of ``dds`` to the chosen loading option
        (date and S3 path).

    Raises:
        ValueError: if no ECG is eligible.
    """
    return _pick_latest_ecg_option(sample_id, dds, ecg_dd_i)


def option_picker_bwh(sample_id, dds):
    """Same as ``option_picker`` but reads the options from ``ecg_dd_i_bwh``.

    Args:
        sample_id: Index value into ``wide_df``.
        dds: Data descriptions that must all receive the same loading option.

    Returns:
        Dict mapping every element of ``dds`` to the chosen loading option
        (date and S3 path).

    Raises:
        ValueError: if no ECG is eligible.
    """
    return _pick_latest_ecg_option(sample_id, dds, ecg_dd_i_bwh)

In [None]:
# This is how all of the components are merged together
sg = DataDescriptionSampleGetter(
    input_data_descriptions=[ecg_dd_i],  # what we want a model to use as input data
    output_data_descriptions=wide_dds,  # what we want a model to predict from the input data
    option_picker=option_picker,
)
# This is how all of the components are merged together
sg_lead_I = DataDescriptionSampleGetter(
    input_data_descriptions=[ecg_dd_i_lead_I],  # what we want a model to use as input data
    output_data_descriptions=wide_dds,  # what we want a model to predict from the input data
    option_picker=option_picker,
)

In [None]:
# This is how all of the components are merged together
sg_bwh = DataDescriptionSampleGetter(
    input_data_descriptions=[ecg_dd_i_bwh],  # what we want a model to use as input data
    output_data_descriptions=wide_dds,  # what we want a model to predict from the input data
    option_picker=option_picker_bwh,
)
sg_bwh_lead_I = DataDescriptionSampleGetter(
    input_data_descriptions=[ecg_dd_i_bwh_lead_I],  # what we want a model to use as input data
    output_data_descriptions=wide_dds,  # what we want a model to predict from the input data
    option_picker=option_picker_bwh,
)

In [None]:
sg_explore_df = pd.read_csv('../../af_survive_explore_all.csv')

In [None]:
working_ids = sg_explore_df[sg_explore_df["error"].isna()]["sample_id"]
dataset = SampleGetterIterableDataset(sample_ids=list(working_ids), sample_getter=sg,
                                           get_epoch=SampleGetterIterableDataset.shuffle_get_epoch)

dataloader = DataLoader(
    dataset, num_workers=14, collate_fn=numpy_collate_fn, batch_size=4,
)

dataset_lead_I = SampleGetterIterableDataset(sample_ids=list(working_ids), sample_getter=sg_lead_I,
                                           get_epoch=SampleGetterIterableDataset.shuffle_get_epoch)
dataloader_lead_I = DataLoader(
    dataset_lead_I, num_workers=12, collate_fn=numpy_collate_fn, batch_size=4, 
)

In [None]:
dataset_bwh = SampleGetterIterableDataset(sample_ids=list(working_ids), sample_getter=sg_bwh,
                                           get_epoch=SampleGetterIterableDataset.shuffle_get_epoch)

dataloader_bwh = DataLoader(
    dataset_bwh, num_workers=14, collate_fn=numpy_collate_fn, batch_size=32,
)
dataset_bwh_lead_I = SampleGetterIterableDataset(sample_ids=list(working_ids), sample_getter=sg_bwh_lead_I,
                                           get_epoch=SampleGetterIterableDataset.shuffle_get_epoch)
dataloader_bwh_lead_I = DataLoader(
    dataset_bwh_lead_I, num_workers=12, collate_fn=numpy_collate_fn, batch_size=4, 
)

In [None]:
from collections import defaultdict

def space_from_dataloader(dataloader, merger, tensor_map, tensor_map_median, max_batches=25000):
    """Encode batches from a dataloader and collect latents plus targets.

    Every batch is first passed through the module-level ``ecg_10s_2_median``
    model, then through the encoding model, and the result is flattened into
    one DataFrame row per sample.

    This cell used to define the function twice (once taking a dict of
    encoders, once taking a merger model); the second definition silently
    replaced the first.  Both call styles are now handled here.

    Args:
        dataloader: Iterable yielding ``(data, target)`` batches.
        merger: Either an object with a ``predict`` method, or a mapping
            (such as ``encoders``) from which the model is looked up with
            ``tensor_map_median``.
        tensor_map: TensorMap whose ``input_name()`` keys the ECG in ``data``.
        tensor_map_median: TensorMap of the median beat; when its last
            dimension is 1 only channel 0 (lead I) is encoded.
        max_batches: Upper bound on the number of batches consumed.

    Returns:
        DataFrame with columns ``latent_<j>`` and one column per target key.
    """
    model = merger if hasattr(merger, 'predict') else merger[tensor_map_median]
    single_lead = tensor_map_median.shape[-1] == 1
    input_key = tensor_map.input_name()

    batches = iter(dataloader)
    space_dict = defaultdict(list)
    for _ in range(max_batches):
        # Only next() is guarded, so a StopIteration raised anywhere else is
        # not mistaken for the end of the data.
        try:
            data, target = next(batches)
        except StopIteration:
            print('loaded all batches')
            break

        median = ecg_10s_2_median(data[input_key])
        # Lead I is index 0.
        encoding = model.predict(median[:, :, :1] if single_lead else median)

        for b in range(encoding.shape[0]):
            for j in range(encoding.shape[-1]):
                space_dict[f'latent_{j}'].append(encoding[b, j])
            for k in target:
                value = target[k][b]
                if isinstance(value, np.float32):
                    space_dict[f'{k}'].append(value)
                else:
                    # Non-scalar target: keep its last element.
                    space_dict[f'{k}'].append(target[k][b, -1])
    return pd.DataFrame.from_dict(space_dict)
#df_ecg_median_ae = space_from_dataloader(dataloader, encoders, ecg_tmap, ecg_tmap_median)
df_ecg_median_ae = space_from_dataloader(dataloader, merger, ecg_tmap, ecg_tmap_median)
#df_ecg_ae_bwh = space_from_dataloader(dataloader_bwh, encoders, ecg_tmap, ecg_tmap_median)
#df_ecg_ae_lead_I = space_from_dataloader(dataloader, encoders_lead_I, ecg_tmap, ecg_tmap_median_lead_I)
#df_ecg_ae_bwh_lead_I = space_from_dataloader(dataloader_bwh, encoders_lead_I, ecg_tmap, ecg_tmap_median_lead_I)

In [None]:
df_ecg_median_ae = df_ecg_median_ae.rename(columns={'MGH_MRN_0': 'sample_id'})
model_name = 'mgh_ecg_rest_median_raw_10_autoencoder_256d_v2022_04_13'
df_ecg_median_ae.to_csv(f'../../trained_models/{model_name}/merged_mgh_latent_{model_name}.tsv', sep='\t', index=False)

In [None]:
df_ecg_median_ae.info()

In [None]:
df_ecg_ae_bwh = df_ecg_ae_bwh.rename(columns={'BWH_MRN_0': 'sample_id'})
model_name = 'mgh_ecg_rest_median_raw_10_autoencoder_256d_v2022_04_13'
df_ecg_ae_bwh.to_csv(f'../../trained_models/{model_name}/bwh_latent_{model_name}.tsv', sep='\t', index=False)

In [None]:
df_ecg_ae_lead_I = df_ecg_ae_lead_I.rename(columns={'MGH_MRN_0': 'sample_id'})
model_name = 'mgh_ecg_rest_median_raw_10_lead_I_autoencoder_256d_v2022_04_09'
df_ecg_ae_lead_I.to_csv(f'../../trained_models/{model_name}/mgh_latent_lead_I_{model_name}.tsv', sep='\t', index=False)

In [None]:
df_ecg_ae_bwh_lead_I = df_ecg_ae_bwh_lead_I.rename(columns={'BWH_MRN_0': 'sample_id'})
model_name = 'mgh_ecg_rest_median_raw_10_lead_I_autoencoder_256d_v2022_04_09'
df_ecg_ae_bwh_lead_I.to_csv(f'../../trained_models/{model_name}/bwh_latent_lead_I_{model_name}.tsv', sep='\t', index=False)

In [None]:
f'../../trained_models/{model_name}/mgh_latent_{model_name}.tsv'

In [None]:
df_ecg_ae_bwh = df_ecg_ae_bwh.rename(columns={'BWH_MRN_0': 'sample_id'})
file_name = f'../../trained_models/{model_name}/bwh_latent_{model_name}.tsv'
df_ecg_ae_bwh.to_csv(file_name, sep='\t', index=False)
print(f'Wrote latent space to: {file_name}')

In [None]:
df_ecg_ae_bwh.info()

In [None]:
df_ecg_ae.info()

In [None]:
dataloader_iterator = iter(dataloader)
data, target = next(dataloader_iterator)
median = ecg_10s_2_median(data[ecg_tmap.input_name()])
encoding = encoders[ecg_tmap_median].predict(median)
print(f'{encoding.shape}')

In [None]:
decoding = decoders[ecg_tmap_median].predict(encoding)
print(f'{decoding.shape}')

In [None]:

plt.plot(np.linspace(0, 10, 5000), data[ecg_tmap.input_name()][0])
plt.xlabel("time (s)")
plt.ylabel("amplitude (mV)")
plt.show()

In [None]:

plt.plot(np.linspace(0, 10, 600), median[2])
plt.xlabel("time (s)")
plt.ylabel("amplitude (mV)")
plt.show()

In [None]:
mo = ecg_autoencoder(median)
plt.plot(np.linspace(0, 10, 600), mo[2])
plt.xlabel("time (s)")
plt.ylabel("amplitude (mV)")
plt.show()

In [None]:
mo2 = decoders[ecg_tmap_median](merger(median))
plt.plot(np.linspace(0, 10, 600), mo2[2])
plt.xlabel("time (s)")
plt.ylabel("amplitude (mV)")
plt.show()

In [None]:
plt.plot(np.linspace(0, 10, 600), decoding[0])
plt.xlabel("time (s)")
plt.ylabel("amplitude (mV)")
plt.show()

In [None]:
df_uni = df_uni[df_uni.MGH_MRN_0.notna()]
df_uni.MGH_MRN_0 = df_uni.MGH_MRN_0.astype(np.int64)
df_uni.set_index("MGH_MRN_0", inplace=True)
df_uni.to_csv('../../mgh_drop_fuse_latent_space_uni.csv')

In [None]:
df_lvef.info()

In [None]:
df = pd.read_csv('../../mgh_drop_fuse_latent_space.csv')
df_mv = pd.read_csv('../../mgh_drop_fuse_latent_space_mv.csv')
df_uni = pd.read_csv('../../mgh_drop_fuse_latent_space_uni.csv')
#df_lvef = pd.read_csv('../../mgh_lvef_latent_space_uni.csv')
df_auto = pd.read_csv('../../mgh_auto_df.csv')



In [None]:
all_scores = {}

In [None]:
phenotypes=['output_age_from_wide_csv_continuous', 'Dem.Gender.no_filter'] + wide_cols
all_scores['zscored'] = latent_space_regression(df, phenotypes, verbose=True)
#all_scores['mgb lvef'] = latent_space_regression(df_lvef, phenotypes, num_features=33, verbose=True)
all_scores['mgb auto'] = latent_space_regression(df_auto, phenotypes, num_features=256, verbose=True)


all_scores['millivolts'] = latent_space_regression(df_mv, phenotypes, verbose=True)
all_scores['unimodal_zscored'] = latent_space_regression(df_uni, phenotypes, verbose=True)

# all_scores['per_individual_normalized2'] = latent_space_regression(df2, phenotypes, verbose=True)
# all_scores['per_individual_normalized_cw712'] = latent_space_regression(df_cw712, phenotypes, verbose=True)
# all_scores['millivolts2'] = latent_space_regression(df_mv2, phenotypes, verbose=True)

In [None]:
plot_nested_dictionary(all_scores)

In [None]:
plot_nested_dictionary(all_scores)

In [None]:
df.start_fu_LVEF.plot.hist()

In [None]:
df.start_fu_LVEF.value_counts()

In [None]:
df[['start_fu_LVEF']].info()