In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from scipy.stats import pearsonr

from torch.utils.data import DataLoader

import tensorflow as tf
from tensorflow.keras.losses import CategoricalCrossentropy

from tensorflow.keras.optimizers import Adam

from ml4h_ccds.data_descriptions.ecg import ECGDataDescription
from ml4h_ccds.data_descriptions.wide_file import WideFileDataDescription
from ml4h_ccds.data_descriptions.util import download_s3_if_not_exists

from ml4h.models.model_factory import block_make_multimodal_multitask_model
from ml4h.TensorMap import TensorMap, Interpretation

from ml4ht.data.util.date_selector import DateRangeOptionPicker, first_dt, DATE_OPTION_KEY, DateRangeOptionPicker
from ml4ht.data.data_description import DataDescription
from ml4ht.data.sample_getter import DataDescriptionSampleGetter
from ml4ht.data.explore import explore_data_descriptions, explore_sample_getter
from ml4ht.data.data_loader import SampleGetterDataset, numpy_collate_fn,SampleGetterIterableDataset

In [None]:
# Root directory handed to the ECG data descriptions below; defaults to the
# current user's home directory.
SESSION_DIR = os.path.expanduser("~")  # downloaded data will be stored here

# ECG data description

In [None]:
def normalize_ecg(ecg, _):
    """
    Divide an ECG by 1000.

    Intended as a unit conversion to millivolts (presumably from microvolts --
    confirm against the data source).

    :param ecg: array-like or scalar amplitude(s) to rescale
    :param _: ignored; present so the function matches the two-argument
        transform call signature used by the data descriptions
    :return: ``ecg / 1000``
    """
    millivolts = ecg / 1000
    return millivolts

def standardize_by_sample_ecg(ecg, _):
    """
    Z-score a single ECG using its own statistics.

    The mean and standard deviation are taken over the whole array (all
    elements pooled together), and 1e-6 is added to the standard deviation so
    a constant signal does not cause a division by zero.

    :param ecg: numeric array holding one ECG
    :param _: ignored; present so the function matches the two-argument
        transform call signature used by the data descriptions
    :return: array of the same shape with (approximately) zero mean, unit variance
    """
    center = np.mean(ecg)
    spread = np.std(ecg) + 1e-6
    return (ecg - center) / spread

# Data description later used as the model INPUT.
# NOTE(review): the name contains "std", but the standardization transform is
# commented out, so no transforms are passed for the input -- confirm this
# asymmetry with the output description below is intentional.
ecg_dd_i = ECGDataDescription(
    SESSION_DIR, 
    name='input_ecg_2500_std_continuous', 
    ecg_len=2500,  # all ECGs will be linearly interpolated to be this length
    #transforms=[standardize_by_sample_ecg],  # these will be applied in order
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_mgh_hd5s'],   # 'ecg_mgh_hd5s',  list of hd5s
)

# Data description later used as the model OUTPUT (prediction target); same
# source and length as the input, but with per-sample standardization applied.
ecg_dd_o = ECGDataDescription(
    SESSION_DIR, 
    name='output_ecg_2500_std_continuous', 
    ecg_len=2500,  # all ECGs will be linearly interpolated to be this length
    transforms=[standardize_by_sample_ecg],  # these will be applied in order
    # data will be automatically localized from s3
    s3_bucket_name='2017P001650', s3_bucket_path=['ecg_mgh_hd5s'],   # 'ecg_mgh_hd5s',  list of hd5s
)

In [None]:
# Try the input data description on one sample id that loaded without error.
sg_explore_df = pd.read_csv('../../af_survive_explore_all.csv')
# keep only the ids whose exploration row recorded no error
working_ids = sg_explore_df.loc[sg_explore_df["error"].isna(), "sample_id"]
sample_id = working_ids.iloc[1]  # second error-free id
options = ecg_dd_i.get_loading_options(sample_id)
print(options)
print(working_ids.head(10))
# plt.plot(np.linspace(0, 10, 5000), ecg_dd_i.get_raw_data(1, options[0]))
# plt.xlabel("time (s)")
# plt.ylabel("amplitude (mV)")
# plt.show()

In [None]:
# IPython imports
%matplotlib inline
mrn = 1519973
options = ecg_dd_i.get_loading_options(mrn)

plt.plot(np.linspace(0, 10, 2500), ecg_dd_i.get_raw_data(mrn, options[-1]))
plt.xlabel("time (s)")
plt.ylabel("amplitude (mV)")
plt.show()

example = ecg_dd_i.get_raw_data(mrn, options[-1])
example.shape

In [None]:
import h5py
# Read one inferred median-beat ECG and plot every lead on a 3x4 grid.
with h5py.File('/home/samuel.friedman/trained_models/mgh_ecg_medians/inferred_hd5s/4282470.hd5', 'r') as hd5:
    ecg = np.array(hd5['ecg_rest_median_raw_10_prediction'])
    print(f'{ecg.shape}')
# Panel order, row by row, of the 3x4 grid
leads = ['I', 'aVR', 'V1', 'V4', 
             'II', 'aVL', 'V2', 'V5', 
             'III', 'aVF', 'V3', 'V6', ]
    
# Lead name -> column index of `ecg`
# NOTE(review): this ordering differs from the map hard-coded in
# reconstruct_phecode -- confirm which one matches this tensor.
channel_map = {
    'I': 0, 'II': 1, 'III': 2, 'V1': 3, 'V2': 4, 'V3': 5,
    'V4': 6, 'V5': 7, 'V6': 8, 'aVF': 9, 'aVL': 10, 'aVR': 11,
}
fig, axes = plt.subplots(3, 4, figsize=(16, 16), dpi=300, sharey=False, sharex=True)
for ax, lead in zip(axes.ravel(), leads):
    # x is the sample index; its length follows the data instead of assuming 600
    ax.plot(range(ecg.shape[0]), ecg[:, channel_map[lead]])
    ax.set_title(f"Lead: {lead}")
    ax.set_xlabel("sample index")
    ax.set_ylabel("amplitude (mV)")
plt.tight_layout()

In [None]:

import h5py
# Read one inferred median-beat ECG and plot every lead on a 3x4 grid.
with h5py.File('/home/samuel.friedman/trained_models/mgh_ecg_medians/inferred_hd5s/5212097.hd5', 'r') as hd5:
    ecg = np.array(hd5['ecg_rest_median_raw_10_prediction'])
    print(f'{ecg.shape}')
# Panel order, row by row, of the 3x4 grid
leads = ['I', 'aVR', 'V1', 'V4', 
             'II', 'aVL', 'V2', 'V5', 
             'III', 'aVF', 'V3', 'V6', ]
    
# Lead name -> column index of `ecg`
# NOTE(review): this ordering differs from the map hard-coded in
# reconstruct_phecode -- confirm which one matches this tensor.
channel_map = {
    'I': 0, 'II': 1, 'III': 2, 'V1': 3, 'V2': 4, 'V3': 5,
    'V4': 6, 'V5': 7, 'V6': 8, 'aVF': 9, 'aVL': 10, 'aVR': 11,
}
fig, axes = plt.subplots(3, 4, figsize=(16, 16), dpi=300, sharey=False, sharex=True)
for ax, lead in zip(axes.ravel(), leads):
    # x is the sample index; its length follows the data instead of assuming 600
    ax.plot(range(ecg.shape[0]), ecg[:, channel_map[lead]])
    ax.set_title(f"Lead: {lead}")
    ax.set_xlabel("sample index")
    ax.set_ylabel("amplitude (mV)")
plt.tight_layout()



In [None]:
# This is how all of the components are merged together: the sample getter
# pairs the input description (ecg_dd_i, no transforms passed) with the output
# description (ecg_dd_o, per-sample standardized).
sg = DataDescriptionSampleGetter(
    input_data_descriptions=[ecg_dd_i],  # what we want a model to use as input data
    output_data_descriptions=[ecg_dd_o],  # what we want a model to predict from the input data
)

In [None]:
# Preview the first rows of the exploration table (includes the "error" and
# "sample_id" columns used to pick working ids).
sg_explore_df.head()

In [None]:
# # Let's test out the SampleGetter on a sample id
# %matplotlib inline
# import matplotlib.pyplot as plt
# sample_id = working_ids.iloc[1]
# in_data, out_data = sg(sample_id)

# ecg = in_data[ecg_dd_i.name]


# plt.plot(np.linspace(0, 10, len(ecg)), ecg)
# plt.title(f"ECG for patient")
# plt.xlabel('time (s)')
# plt.ylabel('amplitude (mV)')
# plt.show()

In [None]:
# Copy pasted from ml4h branch nd_ml4ht_integration
from ml4h.defines import PARTNERS_DATETIME_FORMAT, ECG_REST_AMP_LEADS
def _not_implemented_tensor_from_file(_, __, ___=None):
    """Used to make sure TensorMap is never used to load data"""
    raise NotImplementedError('This TensorMap cannot load data.')
    

def tensor_map_from_data_description(
        data_description: DataDescription,
        interpretation: Interpretation,
        shape,
        name=None,
        **tensor_map_kwargs,
) -> TensorMap:
    """
    Wrap a DataDescription in a TensorMap so it can be passed to the model factory.

    The resulting TensorMap only carries metadata; its ``tensor_from_file`` is a
    stub that raises, so it cannot be used to load data.

    :param data_description: supplies the TensorMap name when ``name`` is falsy
    :param interpretation: forwarded to ``TensorMap`` unchanged
    :param shape: forwarded to ``TensorMap`` unchanged
    :param name: optional explicit TensorMap name; overrides ``data_description.name``
    :param tensor_map_kwargs: extra keyword arguments forwarded to ``TensorMap``
    :return: the constructed TensorMap
    """
    return TensorMap(
        name=name or data_description.name,
        interpretation=interpretation,
        shape=shape,
        tensor_from_file=_not_implemented_tensor_from_file,
        **tensor_map_kwargs,
    )


# Metadata-only TensorMap for the 2500-sample, 12-channel ECG. The explicit
# name 'ecg_2500_std' overrides the data description's own name.
ecg_tmap = tensor_map_from_data_description(
    ecg_dd_i,
    Interpretation.CONTINUOUS,
    (2500, 12), name='ecg_2500_std',
    channel_map=ECG_REST_AMP_LEADS
)


# Building a model

In [None]:

from typing import Dict, Union, Callable
from matplotlib.figure import Figure

def unit_vector(vector):
    """
    Scale ``vector`` to unit length.

    :param vector: numeric array
    :return: ``vector`` divided by its norm as computed by ``np.linalg.norm``;
        a zero vector is not special-cased
    """
    magnitude = np.linalg.norm(vector)
    return vector / magnitude
def _plot_partners_full(voltage: Dict[str, np.ndarray], ax: plt.Axes) -> None:
    """
    Draw every lead in ``voltage`` as its own full-length trace, stacked vertically.

    The axes' current y-range is split into one equal band per lead; each trace
    is shifted to the middle of its band and labelled with the lead name at the
    left edge. Ends by calling ``plt.show()``.

    :param voltage: lead name -> 2500-sample signal (at most 12 leads fit the buffer)
    :param ax: axes to draw on; its y-limits must already be set
    """
    lead_names = list(voltage)
    n_leads = len(lead_names)

    # copy the signals into a fixed 12 x 2500 buffer (unused rows stay NaN)
    stacked = np.full((12, 2500), np.nan)
    for row, name in enumerate(lead_names):
        stacked[row] = voltage[name]

    # height of the vertical band allotted to each lead
    y_low, y_high = ax.get_ylim()
    band = (y_high - y_low) / n_leads

    label_dx, label_dy = 5, -0.01

    # first lead goes at the top, last lead at the bottom
    for row, name in enumerate(lead_names):
        baseline = (n_leads - row - 0.5) * band
        ax.plot(stacked[row] + baseline, color="black", linewidth=0.375)
        ax.text(
            label_dx,
            baseline + label_dy,
            name,
            ha="left",
            va="top",
            weight="bold",
        )
    plt.show()
        
def _plot_partners_clinical(voltage: Dict[str, np.ndarray], ax: plt.Axes, label='black') -> None:
    """
    Draw a 12-lead ECG in the clinical 3x4 + 3 rhythm-strip layout.

    Rows 0-2 each show four leads, one per 625-sample quarter of the 2500-sample
    trace, with a small gap left between neighbouring quarters. Rows 3-5 show
    the full-length V1, II and V5 traces. All values are divided by 1000 before
    plotting, and each row is shifted into its own band of the axes' y-range.

    :param voltage: lead name -> 2500-sample signal; must contain all 12 standard leads
    :param ax: axes to draw on; its y-limits must already be set
    :param label: legend label attached to the first row; the value 'Original'
        selects the red style, anything else the semi-transparent blue style
    """
    halfgap = 5
    quarter = 625
    grid_leads = (
        ("I", "aVR", "V1", "V4"),
        ("II", "aVL", "V2", "V5"),
        ("III", "aVF", "V3", "V6"),
    )
    rhythm_leads = ("V1", "II", "V5")

    # get voltage in clinical chunks: each grid lead contributes only the
    # samples that fall inside its own quarter, trimmed by the gap on any
    # side that borders another lead
    clinical = np.full((6, 2500), np.nan)
    for row, row_leads in enumerate(grid_leads):
        for col, lead in enumerate(row_leads):
            start = col * quarter + (halfgap if col > 0 else 0)
            stop = (col + 1) * quarter - (halfgap if col < 3 else 0)
            clinical[row][start:stop] = voltage[lead][start:stop]
    for row, lead in enumerate(rhythm_leads, start=3):
        clinical[row] = voltage[lead]

    # convert voltage to millivolts
    clinical /= 1000

    # calculate space between leads
    n_rows = len(clinical)
    y_low, y_high = ax.get_ylim()
    band = (y_high - y_low) / n_rows

    label_dx = 5
    label_dy = -0.1

    if label == 'Original':
        line_style = dict(color='red', alpha=1.0, linestyle='solid', linewidth=1.475)
    else:
        line_style = dict(color='blue', alpha=0.5, linestyle='solid', linewidth=1.875)

    # lead names to print on each row, left to right
    row_labels = grid_leads + tuple((lead,) for lead in rhythm_leads)

    # plot signal and add labels
    for row, names in enumerate(row_labels):
        baseline = (n_rows - row - 0.5) * band
        if row == 0:
            # only the first row carries the legend entry
            ax.plot(clinical[row] + baseline, label=label, **line_style)
        else:
            ax.plot(clinical[row] + baseline, **line_style)
        for col, name in enumerate(names):
            ax.text(
                col * quarter + label_dx,
                baseline + label_dy,
                name,
                ha="left",
                va="top",
                weight="bold",
            )


# Page size (inches) shared by the ECG-paper figures below.
_ECG_PAGE_WIDTH_IN = 11
_ECG_PAGE_HEIGHT_IN = 8.5


def _new_ecg_paper_axes(dpi):
    """
    Create a letter-size Figure holding one axes styled like ECG paper.

    The axes is scaled to 10 mm/mV vertically and 25 mm/s horizontally for a
    250 Hz signal, with major/minor red grid lines and no tick labels.

    :param dpi: resolution of the created Figure
    :return: ``(fig, ax, footer)`` where ``footer`` is the scale caption text
    """
    plt.rcParams["font.size"] = 9.5

    w, h = _ECG_PAGE_WIDTH_IN, _ECG_PAGE_HEIGHT_IN
    fig = Figure(figsize=(w, h), dpi=dpi)

    # define plot area in inches
    left = 0.17
    bottom = h - 7.85
    width = w - 2 * left
    height = h - bottom - 2.3

    # ecg plot area
    ax = fig.add_axes([left / w, bottom / h, width / w, height / h])

    # voltage is in microvolts
    # the entire plot area is 5.55 inches tall, 10.66 inches wide (141 mm, 271 mm)
    # the resolution on the y-axis is 10 mm/mV
    # the resolution on the x-axis is 25 mm/s
    mm_per_inch = 25.4

    # 1. set y-limit to max 14.1 mV
    y_res = 10  # mm/mV
    min_y = 0
    max_y = height * mm_per_inch / y_res
    ax.set_ylim(min_y, max_y)

    # 2. set x-limit to max 10.8 s, center 10 s leads
    sampling_frequency = 250  # Hz
    x_res = 25  # mm/s
    max_x = width * mm_per_inch / x_res
    x_buffer = (max_x - 10) / 2
    max_x -= x_buffer
    min_x = -x_buffer
    # convert the limits from seconds to samples
    max_x *= sampling_frequency
    min_x *= sampling_frequency
    ax.set_xlim(min_x, max_x)

    # 3. set ticks for every 0.1 mV or every 1/25 s
    y_tick = 1 / y_res
    x_tick = 1 / x_res * sampling_frequency
    ax.set_xticks(np.arange(min_x, max_x, x_tick * 5))
    ax.set_xticks(np.arange(min_x, max_x, x_tick), minor=True)
    ax.set_yticks(np.arange(min_y, max_y, y_tick * 5))
    ax.set_yticks(np.arange(min_y, max_y, y_tick), minor=True)

    ax.tick_params(
        which="both", left=False, bottom=False, labelleft=False, labelbottom=False
    )
    # Pass the on/off flag positionally: the old ``b=`` keyword was renamed to
    # ``visible=`` and is rejected by recent Matplotlib releases.
    ax.grid(True, color="r", which="major", lw=0.5)
    ax.grid(True, color="r", which="minor", lw=0.2)

    footer = f"{x_res}mm/s    {y_res}mm/mV    {sampling_frequency}Hz"
    return fig, ax, footer


def _save_ecg_paper_figure(fig, footer, out_path):
    """Add the scale caption to ``fig`` and write it to ``{out_path}.png``."""
    # bottom text
    fig.text(
        0.17 / _ECG_PAGE_WIDTH_IN,
        0.46 / _ECG_PAGE_HEIGHT_IN,
        footer,
        ha="left",
        va="center",
        weight="bold",
    )

    # os.makedirs('') raises, so only create a directory when the path has one
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_path}.png', format='png')


def _two_ecgs_clinical(
    data1: Dict[str, Union[np.ndarray, str, Dict]],
    data2: Dict[str, Union[np.ndarray, str, Dict]],
    plot_signal_function: Callable[[Dict[str, np.ndarray], plt.Axes], None],
    out_path, label1='Original', label2='Reconstruction'
):
    """
    Overlay two ECGs on one ECG-paper figure and save it as a PNG.

    :param data1: first ECG, passed as-is to ``plot_signal_function``
    :param data2: second ECG, passed as-is to ``plot_signal_function``
    :param plot_signal_function: called as ``f(data, ax, label=...)`` once per ECG
    :param out_path: output path without extension; ``.png`` is appended
    :param label1: legend label for ``data1``
    :param label2: legend label for ``data2``
    :return: the 300-dpi Figure that was saved
    """
    fig, ax, footer = _new_ecg_paper_axes(dpi=300)

    plot_signal_function(data1, ax, label=label1)
    plot_signal_function(data2, ax, label=label2)
    ax.legend()

    _save_ecg_paper_figure(fig, footer, out_path)
    return fig


# An earlier definition of this name in the same cell was shadowed by this one
# and relied on a `_plot_partners_text` helper that is not defined here; only
# this (effective) definition is kept.
def _plot_partners_figure(
    data: Dict[str, Union[np.ndarray, str, Dict]],
    plot_signal_function: Callable[[Dict[str, np.ndarray], plt.Axes], None],
    out_path,
):
    """
    Plot a single ECG on an ECG-paper figure and save it as a PNG.

    :param data: ECG passed as-is to ``plot_signal_function``
    :param plot_signal_function: called as ``f(data, ax)`` to draw the signal
    :param out_path: output path without extension; ``.png`` is appended
    :return: the 100-dpi Figure that was saved
    """
    fig, ax, footer = _new_ecg_paper_axes(dpi=100)

    # signal plot
    plot_signal_function(data, ax)

    _save_ecg_paper_figure(fig, footer, out_path)
    return fig



from IPython.display import display

def reconstruct_phecode(decoders, tensor_map, latent_df, test_df, latent_cols, 
                        phecode_file, out_path, scalar=4.0, mrn=0, label1='Control',
                        label2='Reconstruction', centroid_method=True):
    """
    Decode and plot latent-space ECGs contrasting PheCode cases and controls.

    The PheCode table is joined to ``latent_df``; rows are split into cases
    (``has_disease >= 1``) and controls, and the mean latent vector of each
    group is computed. Two 12-lead ECGs are then decoded and overlaid on a 3x4
    grid which is saved to ``{out_path}{phecode_name}.png``.

    :param decoders: mapping indexed by ``tensor_map`` whose value has ``predict``
    :param tensor_map: key selecting the decoder
    :param latent_df: DataFrame with ``LINKER_ID``, the ``latent_cols`` columns and,
        when ``centroid_method`` is False, ``MRN``
    :param test_df: unused; kept for interface compatibility
    :param latent_cols: names of the latent-dimension columns
    :param phecode_file: tab-separated file with ``linker_id``, ``phenotype`` and
        ``has_disease`` columns; ignored unless the path contains ``'PheCode'``
    :param out_path: output prefix; the phecode name and ``.png`` are appended
    :param scalar: step size along the phenotype vector (non-centroid mode only)
    :param mrn: patient whose encoding is shifted (non-centroid mode only)
    :param label1: legend label for the first decoded ECG
    :param label2: legend label for the second decoded ECG
    :param centroid_method: if True decode the control and case centroids;
        otherwise decode one patient's encoding before and after shifting it
    :return: None; returns early without plotting when the file name does not
        match or fewer than 100 cases are found
    """
    if 'PheCode' not in phecode_file:
        return
    df = pd.read_csv(phecode_file, sep='\t')
    phecode_name = f'phe_{df.iloc[0].phenotype}'.replace('.', '_').replace(' ', '')

    ratio = df.has_disease.sum() / len(df.has_disease)
    print(f'phecode_name  {phecode_name} ratio: {ratio}')
    print([c for c in df if 'latent' not in c] )   #phecode_name, 1, 0
    df = df.rename(columns={'has_disease': phecode_name})
    latent_df = pd.merge(latent_df, df, left_on='LINKER_ID', right_on='linker_id', how='inner')
    print(f'MERGED phecode_name  {phecode_name}')
    latent_df.info()
    hit = latent_df.loc[latent_df[phecode_name] >= 1]
    miss = latent_df.loc[latent_df[phecode_name] < 1]
    hit_np = hit[latent_cols].to_numpy()

    # Require a minimum number of cases. Checked before any averaging so that
    # an empty group never reaches np.mean (which would emit warnings/NaNs).
    if len(hit_np) < 100:
        return

    miss_np = miss[latent_cols].to_numpy()
    positroid = np.mean(hit_np, axis=0)  # mean latent vector of the cases
    negatroid = np.mean(miss_np, axis=0)  # mean latent vector of the controls

    if centroid_method:
        y = decoders[tensor_map].predict(np.array([negatroid]))[0]
        cn = closest_node(positroid, hit_np)
        print(f'cn shape {hit_np[cn].shape} cn is {cn}')
        yp = decoders[tensor_map].predict(np.array([positroid]))[0]
    else:
        # NOTE(review): this vector points from the case centroid towards the
        # control centroid, so adding it moves the encoding AWAY from the
        # phenotype -- confirm this is the intended direction for `label2`.
        phenotype_vector = negatroid - positroid
        mencode = latent_df[latent_df.MRN == mrn][latent_cols].to_numpy()
        y = decoders[tensor_map].predict(mencode)[0]
        yp = decoders[tensor_map].predict(mencode+(scalar*phenotype_vector))[0]

    print(f'yp is {yp.shape} \n y is {y.shape}')

    # Panel order, row by row, of the 3x4 grid
    leads = ['I', 'aVR', 'V1', 'V4', 
             'II', 'aVL', 'V2', 'V5', 
             'III', 'aVF', 'V3', 'V6', ]

    width = 16
    height = 12
    fig, axes = plt.subplots(3, 4, figsize=(width, height), dpi=300, sharey=True)

    # Lead name -> column of the decoded ECG.
    # NOTE(review): hard-coded rather than read from tensor_map.channel_map;
    # confirm the two agree for the model being decoded.
    channel_map = {
    'I': 0, 'II': 1, 'III': 2, 'V1': 6, 'V2': 7, 'V3': 8,
    'V4': 9, 'V5': 10, 'V6': 11, 'aVF': 5, 'aVL': 4, 'aVR': 3,
    }

    # Axis limits and ticks are identical for every panel, so build them once.
    max_y = 4.1
    min_y = -4
    x_major_ticks = np.arange(0, 601, 100)
    x_minor_ticks = np.arange(0, 601, 50)
    y_major_ticks = np.arange(min_y, max_y, 2.0)
    y_minor_ticks = np.arange(min_y, max_y, 0.1)

    for ax, lead in zip(axes.ravel(), leads):
        ax.set_ylim(min_y, max_y)

        ax.set_xticks(x_major_ticks)
        ax.set_xticks(x_minor_ticks, minor=True)
        # samples 0..600 are labelled as 0..1.2 seconds
        ax.set_xticklabels([0,0.2,0.4,0.6,0.8,1.0,1.2])
        ax.set_yticks(y_major_ticks)
        ax.set_yticks(y_minor_ticks, minor=True)
        # Pass the on/off flag positionally: the old ``b=`` keyword was renamed
        # to ``visible=`` and is rejected by recent Matplotlib releases.
        ax.grid(True, color="k", which="major", lw=0.5)
        ax.grid(True, color="k", which="minor", lw=0.2)

        ax.plot(y[:, channel_map[lead]], c="blue", lw=3, label=label1)
        ax.plot(yp[:, channel_map[lead]], c="red", lw=5, alpha=0.5, label=label2)
        ax.set_title(f"Lead: {lead}")
        ax.legend(fontsize = 12)

        ax.set_xlabel('Seconds')
        ax.set_ylabel('mV')
    plt.tight_layout()

    # os.makedirs('') raises, so only create a directory when the prefix has one
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_path}{phecode_name}.png', format='png')

    print(f'Plotted ECGs for {phecode_name} positive N: {len(hit)} negative N: {len(miss)}')
    
def closest_node(node, nodes):
    """
    Find the row of ``nodes`` nearest to ``node`` in Euclidean distance.

    :param node: 1-D array of length d
    :param nodes: 2-D array of shape (n, d)
    :return: index of the nearest row (the first one when several are tied)
    """
    squared_dist = np.sum((nodes - node) ** 2, axis=1)
    # argmin is a single linear pass; a full sort is not needed for the minimum
    return np.argmin(squared_dist)

def closest_node_cos(a, nodes):
    """
    Find the row of ``nodes`` whose direction is closest to ``a``.

    Closeness is measured by cosine similarity, so the nearest node is the one
    with the LARGEST similarity. (The previous implementation took the argmin
    of the similarities and therefore returned the least similar node.)

    :param a: 1-D array of length d
    :param nodes: array-like of shape (n, d)
    :return: index of the row with the highest cosine similarity to ``a``
    """
    a = np.asarray(a)
    nodes = np.asarray(nodes)
    similarity = nodes @ a / (np.linalg.norm(nodes, axis=1) * np.linalg.norm(a))
    return np.argmax(similarity)

In [None]:
import sys
from ml4h.arguments import parse_args
from ml4h.metrics import coefficient_of_determination
from ml4h.explorations import latent_space_dataframe
from ml4h.models.model_factory import block_make_multimodal_multitask_model

# NOTE(review): the first assignment is immediately overwritten, so only the
# biosppy median model below is actually used.
model_name = 'mgh_ecg_rest_median_raw_10_autoencoder_256d_v2022_04_13'
model_name = 'mgh_biosppy_median_60bpm_autoencoder_256d_v2022_05_21'


# Emulate an ml4h command line so parse_args() builds the argument namespace.
# NOTE(review): '--pool_type' is supplied twice ('max', then 'average'); with
# standard argparse the later value wins -- confirm that is what is intended.
sys.argv = ['train',
            '--input_tensors', 'ecg_biosppy_median_60bpm_mgb',
            '--output_tensors', 'ecg_biosppy_median_60bpm_mgb',
            '--activation', 'mish',
            '--block_size', '8', '--conv_width', '61', '--conv_layers', '64', '64', '--pool_type', 'max', 
            '--dense_blocks', '64', '64', '--dense_layers', '128', '--pool_type', 'average',
            '--learning_rate', '0.00002',
            '--encoder_blocks', 'conv_encode', '--merge_blocks', '--decoder_blocks', 'conv_decode',
           '--tensormap_prefix', 'ml4h.tensormap.ukb.ecg',
            '--model_file', f'../../trained_models/{model_name}/{model_name}.h5',
           
           ]
args = parse_args()


# ecg_tmap_median = tensor_map_from_data_description(
#     ecg_dd_i,
#     Interpretation.CONTINUOUS,
#     (600, 12), name='ecg_rest_median_raw_10',
#     channel_map=ECG_REST_AMP_LEADS
# )
# args.tensor_maps_in = [ecg_tmap_median]
# args.tensor_maps_out = [ecg_tmap_median]

# Build the autoencoder from the parsed arguments; the factory returns the
# full model together with its encoders, decoders and merger.
ecg_autoencoder, encoders, decoders, merger = block_make_multimodal_multitask_model(**args.__dict__)

In [None]:
# Linker table; provides the 'MRN' column used as the join key below.
links = pd.read_csv('../../mrn_linker.txt', sep='\t')
# latent = pd.read_csv('mgh_drop_fuse_latent_space.csv')
# latent_df = pd.merge(latent, links, left_on='MGH_MRN_0', right_on='MRN', how='inner')
# NOTE(review): absolute, user-specific path -- consider a configurable base directory.
lf='/home/samuel.friedman/trained_models/mgh_biosppy_median_60bpm_autoencoder_256d_v2022_05_21/hidden_median_mgh_biosppy_median_60bpm_autoencoder_256d_v2022_05_21.tsv'


#latent = pd.read_csv('./trained_models/ecg_2500_autoencoder_mgh_c3po_128d_v2021_12_17/mgh_latent_ecg_2500_autoencoder_mgh_c3po_128d_v2021_12_17.tsv', sep='\t')
#latent = pd.read_csv('../../trained_models/mgh_ecg_rest_median_raw_10_autoencoder_256d_v2022_04_13/merged_mgh_latent_mgh_ecg_rest_median_raw_10_autoencoder_256d_v2022_04_13.tsv', sep='\t')
#latent = pd.read_csv('./trained_models/mgh_ecg_rest_median_raw_10_lead_I_autoencoder_256d_v2022_04_09/mgh_latent_lead_I_mgh_ecg_rest_median_raw_10_lead_I_autoencoder_256d_v2022_04_09.tsv', sep='\t')
# Latent-space table for the model loaded above (tab-separated).
latent = pd.read_csv(lf, sep='\t')



# Inner join: keep only rows whose sample_id has a matching MRN in the linker.
latent_df = pd.merge(latent, links, left_on='sample_id', right_on='MRN', how='inner')
latent_dimension = 256
# Column names of the latent dimensions: latent_0 ... latent_255
latent_cols = [f'latent_{i}' for i in range(latent_dimension)]
latent_df.info()

In [None]:
# PheCode_805: 'Control' vs. 'Right Bundle Branch Block', centroid method.
reconstruct_phecode(
    decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
    phecode_file='../../phecodes/PheCode_805.txt',
    out_path='../../ecg_phecode_reconstruct/',
    label1='Control',
    label2='Right Bundle Branch Block',
    centroid_method=True,
)

In [None]:
# PheCode_806: 'Left Bundle Branch Block', centroid method.
reconstruct_phecode(
    decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
    phecode_file='../../phecodes/PheCode_806.txt',
    out_path='../../ecg_phecode_reconstruct/',
    label2='Left Bundle Branch Block',
    centroid_method=True,
)

In [None]:
# PheCode_388: 'Obesity', centroid method.
reconstruct_phecode(
    decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
    phecode_file='../../phecodes/PheCode_388.txt',
    out_path='../../ecg_phecode_reconstruct/',
    label2='Obesity',
    centroid_method=True,
)

In [None]:
# PheCode_793: 'Hypertrophic Cardiomyopathy', centroid method.
reconstruct_phecode(
    decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
    phecode_file='../../phecodes/PheCode_793.txt',
    out_path='../../ecg_phecode_reconstruct/',
    label2='Hypertrophic Cardiomyopathy',
    centroid_method=True,
)

In [None]:
# PheCode_793: 'Hypertrophic Cardiomyopathy', centroid method.
# NOTE(review): this cell repeats the PheCode_793 call of the preceding cell with
# identical arguments; one of the two can probably be deleted.
reconstruct_phecode(
    decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
    phecode_file='../../phecodes/PheCode_793.txt',
    out_path='../../ecg_phecode_reconstruct/',
    label2='Hypertrophic Cardiomyopathy',
    centroid_method=True,
)

In [None]:
# Run reconstruct_phecode (centroid method) for a batch of phecodes.
# Each entry is (key into phe2file, label2 passed to reconstruct_phecode); the order
# matches the original sequence of hand-written calls.
# NOTE(review): `phe2file` is built in a cell that appears LATER in this notebook, so
# this cell fails with NameError on Restart & Run All unless that cell is moved up.
phecode_labels = [
    ("496_3", 'Bronchiectasis'),
    ("250", 'Diabetes mellitus'),
    ("272_1", 'Hyperlipidemia'),
    ("274_1", 'Gout'),
    ("585", 'Renal failure'),
    ("600", 'Hyperplasia of prostate'),
    ("574", 'Cholelithiasis and cholecystitis'),
    ("562", 'Diverticulosis and diverticulitis'),
    ("327_3", 'Sleep apnea'),
    ("345", 'Epilepsy, recurrent seizures, convulsions'),
    ("331_9", 'Cerebral degeneration, unspecified'),
    ("316", 'Substance addiction and disorders'),
    ("285", 'Other anemias'),
    ("286", 'Coagulation defects'),
]

for phecode, phecode_label in phecode_labels:
    reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
                        phecode_file=phe2file[phecode],
                        out_path='../../ecg_phecode_reconstruct/',
                        label2=phecode_label, centroid_method=True)



In [None]:
# PheCode_370: 'Hyperpotassemia', centroid method.
reconstruct_phecode(
    decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
    phecode_file='../../phecodes/PheCode_370.txt',
    out_path='../../ecg_phecode_reconstruct/',
    label2='Hyperpotassemia',
    centroid_method=True,
)

In [None]:
# PheCode_371: 'Hypopotassemia'. Unlike the neighbouring cells, centroid_method is
# not passed here, so reconstruct_phecode's default applies.
reconstruct_phecode(
    decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
    phecode_file='../../phecodes/PheCode_371.txt',
    out_path='../../ecg_phecode_reconstruct/',
    label2='Hypopotassemia',
)

In [None]:
# PheCode_798: 'AV block'. centroid_method is not passed, so reconstruct_phecode's
# default applies.
reconstruct_phecode(
    decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
    phecode_file='../../phecodes/PheCode_798.txt',
    out_path='../../ecg_phecode_reconstruct/',
    label2='AV block',
)

In [None]:
# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_388.txt',
#                                     out_path='../../ecg_phecode_reconstruct/obesity', 
#                     scalar=5.0, mrn=3467344, label2='Obesity Vector Transformation', centroid_method=True)

# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_388.txt',
#                                     out_path='../../ecg_phecode_reconstruct/obesity', 
#                     scalar=5.0, mrn=1130296, label2='Obesity Vector Transformation')

# The only active call in this cell: PheCode_806 ('Left Bundle Branch Block') for a
# single MRN with scalar=2.0, written under the 'lbbb' output prefix.
# NOTE(review): a patient identifier (MRN) is hard-coded here; confirm it is
# acceptable to keep it in a shared notebook.
reconstruct_phecode(
    decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols,
    phecode_file='../../phecodes/PheCode_806.txt',
    out_path='../../ecg_phecode_reconstruct/lbbb',
    scalar=2.0,
    mrn=3467344,
    label2='Left Bundle Branch Block',
)

# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_806.txt',
#                                     out_path='../../ecg_phecode_reconstruct/lbbb', 
#                     scalar=2.0, mrn=1130296,
#                    label2='Left Bundle Branch Block')

# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_1182.txt',
#                                     out_path='../../ecg_phecode_reconstruct/arf', 
#                     scalar=3.0, mrn=3467344, label2='Acute Renal Failure')


# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_798.txt',
#                                     out_path='../../ecg_phecode_reconstruct/lbbb', 
#                     scalar=1.0, mrn=3467344,
#                    label2='AV block')

# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_793.txt',
#                                     out_path='../../ecg_phecode_reconstruct/lbbb', 
#                     scalar=4.0, mrn=1130296,
#                    label2='HCM', centroid_method=True)

# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_955.txt',
#                                     out_path='../../ecg_phecode_reconstruct/lbbb', 
#                     scalar=2.0, mrn=1130296,
#                    label2='Chronic Bronchitis')



# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_1053.txt',
#                                     out_path='../../ecg_phecode_reconstruct/lbbb', 
#                     scalar=5.0, mrn=3467344,
#                    label2='GERD')

# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_545.txt',
#                                     out_path='../../ecg_phecode_reconstruct/lbbb', 
#                     scalar=4.0, mrn=1130296,
#                    label2='Obstructive Sleep Apnea')

# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_370.txt',
#                                     out_path='../../ecg_phecode_reconstruct/lbbb', 
#                     scalar=2.0, mrn=1130296,
#                    label2='Hyperpotassemia')

# reconstruct_phecode(decoders, args.tensor_maps_in[0], latent_df, latent_df, latent_cols, 
#                                     phecode_file='../../phecodes/PheCode_371.txt',
#                                     out_path='../../ecg_phecode_reconstruct/lbbb', 
#                     scalar=12.0, mrn=1130296,
#                    label2='Hypopotassemia')

In [None]:
import os
import h5py
def _median_ecg(mrns, hd5s, n, hd5_key, verbose=False):
    """Return the element-wise median of up to `n` ECG arrays found on disk for `mrns`.

    For each MRN, `<hd5s>/<MRN>.hd5` is opened if it exists and dataset `hd5_key` is
    read as float32. MRNs without a file are skipped. As in the original loop, at
    least one file is read (when available) before the `n` limit is checked.

    Raises:
        ValueError: if no file was found for any MRN (the median would be undefined).
    """
    ecgs = []
    for mrn in mrns:
        f = os.path.join(hd5s, f'{mrn}.hd5')
        if not os.path.exists(f):
            continue
        with h5py.File(f, 'r') as hd5:
            ecgs.append(np.array(hd5[hd5_key], dtype=np.float32))
        if verbose:
            # NOTE(review): this prints patient identifiers into the notebook output;
            # clear outputs before sharing.
            print(f'MRN chosen {mrn}')
        if len(ecgs) >= n:
            break
    if not ecgs:
        raise ValueError(f'No hd5 files with ECGs found under {hd5s} for the requested MRNs')
    return np.median(np.array(ecgs), axis=0)


def plot_phecode(phecode_file,tensor_map, n=12,
                 label1='Normal',
                    label2='Reconstruction',
                 hd5s='/home/samuel.friedman/trained_models/mgh_ecg_medians/inferred_hd5s/',
                 hd5_key='ecg_rest_median_raw_10_prediction'):
    """Plot the median 12-lead ECG of phecode-negative vs. phecode-positive patients.

    The phecode table is joined to '../../mrn_linker.txt' on linker id, split into
    positives (`has_disease >= 1`) and negatives (`has_disease < 1`), and for each
    group the median over up to `n` ECGs read from `hd5s` is drawn on a 3x4 grid of
    leads (negatives in blue with `label1`, positives in red with `label2`).

    Args:
        phecode_file: tab-separated phecode file; paths that do not contain 'PheCode'
            are ignored and the function returns None without plotting.
        tensor_map: kept for interface compatibility; it is not used for plotting.
            NOTE(review): the lead-to-column mapping below is hard-coded and
            `tensor_map.channel_map` is not consulted -- confirm they agree.
        n: maximum number of ECGs per group that enter the median.
        label1: legend label for the phecode-negative median.
        label2: legend label for the phecode-positive median.
        hd5s: directory holding one `<MRN>.hd5` file per patient.
        hd5_key: name of the dataset read from each hd5 file.

    Raises:
        ValueError: if no ECG file exists for one of the two groups.
    """
    if 'PheCode' not in phecode_file:
        return
    df = pd.read_csv(phecode_file, sep='\t')
    phecode_name = f'phe_{df.iloc[0].phenotype}'.replace('.', '_').replace(' ', '')
    links = pd.read_csv('../../mrn_linker.txt', sep='\t')
    df = pd.merge(df, links, left_on='linker_id', right_on='LINKER_ID', how='inner')
    df = df.rename(columns={'has_disease': phecode_name})

    hit = df.loc[df[phecode_name] >= 1]
    miss = df.loc[df[phecode_name] < 1]

    # Only the positive group reports which MRNs were used, as before.
    yp = _median_ecg(hit['MRN'], hd5s, n, hd5_key, verbose=True)
    y = _median_ecg(miss['MRN'], hd5s, n, hd5_key)

    # Plot order of the leads on the 3x4 grid (row-major).
    leads = ['I', 'aVR', 'V1', 'V4',
             'II', 'aVL', 'V2', 'V5',
             'III', 'aVF', 'V3', 'V6', ]
    # Column index of each lead in the last axis of the ECG arrays.
    channel_map = {
        'I': 0, 'II': 1, 'III': 2, 'V1': 3, 'V2': 4, 'V3': 5,
        'V4': 6, 'V5': 7, 'V6': 8, 'aVF': 9, 'aVL': 10, 'aVR': 11,
    }

    fig, axes = plt.subplots(3, 4, figsize=(18, 12), dpi=300, sharey=True)
    for ax, lead in zip(axes.ravel(), leads):
        ax.plot(y[:, channel_map[lead]], c="blue", label=label1)
        ax.plot(yp[:, channel_map[lead]], c="red", lw=2, alpha=0.7, label=label2)
        ax.set_title(f"Lead: {lead}")
        ax.legend()
    plt.tight_layout()
    #plt.show()
    print(f'Plotted ECGs for {phecode_name} positive N: {len(hit)} negative N: {len(miss)}')
    

In [None]:
#plot_phecode('../../phecodes/PheCode_388.txt', args.tensor_maps_in[0], label2='Obesity')
# Active example: PheCode_806 (the other example calls in this cell are disabled).
plot_phecode(
    '../../phecodes/PheCode_806.txt',
    args.tensor_maps_in[0],
    label2='Left Bundle Branch Block',
)
# plot_phecode('../../phecodes/PheCode_1182.txt', args.tensor_maps_in[0], label2='Acute Renal Failure')
# plot_phecode('../../phecodes/PheCode_798.txt', args.tensor_maps_in[0], label2='AV Block')
# plot_phecode('../../phecodes/PheCode_793.txt', args.tensor_maps_in[0], label2='Hypertrophic obstructive cardiomyopathy')

# plot_phecode('../../phecodes/PheCode_955.txt', args.tensor_maps_in[0], label2='Chronic Bronchitis')

# plot_phecode('../../phecodes/PheCode_545.txt', args.tensor_maps_in[0], label2='Obstructive Sleep Apnea')



In [None]:
# Re-read the linker table and print its column/dtype summary.
# NOTE(review): the same file is already loaded into `links` by an earlier cell.
links = pd.read_csv('../../mrn_linker.txt', sep='\t')
links.info()

In [None]:
from scipy import stats
import seaborn as sb
def write_phecode(latent_df, phecode_file, out_path):
    """Join a phecode table onto `latent_df` and save the non-latent columns as CSV.

    Args:
        latent_df: frame with a 'LINKER_ID' column; it is not modified.
        phecode_file: tab-separated file with 'linker_id', 'phenotype' and
            'has_disease' columns. Paths without 'PheCode' in them are ignored
            (the function returns None and writes nothing).
        out_path: output path without extension; '.csv' is appended.

    The 'has_disease' column is renamed to 'phe_<phenotype>' (dots replaced by
    underscores, spaces removed), taking the phenotype from the first row.
    """
    if 'PheCode' not in phecode_file:
        return

    phe_table = pd.read_csv(phecode_file, sep='\t')
    first_phenotype = phe_table.iloc[0].phenotype
    phecode_name = f'phe_{first_phenotype}'.replace('.', '_').replace(' ', '')

    # Fraction of rows flagged as cases; reported for information only.
    case_fraction = phe_table.has_disease.sum() / len(phe_table.has_disease)
    print(f'phecode_name  {phecode_name} ratio: {case_fraction} len phe: {len(phe_table)}')
    print([name for name in phe_table.columns if 'latent' not in name])

    phe_table = phe_table.rename(columns={'has_disease': phecode_name})
    merged = pd.merge(
        latent_df,
        phe_table,
        left_on='LINKER_ID',
        right_on='linker_id',
        how='inner',
    )
    print(f'MERGED phecode_name  {phecode_name}')
    merged.info()

    # Drop every column whose name contains 'latent' before writing.
    kept_columns = [name for name in merged.columns if 'latent' not in name]
    merged[kept_columns].to_csv(f'{out_path}.csv', index=False)

     

    
def histogram_phecode(latent_df, phecode_file, phecode_title, out_path, random_permute=False,
                      latent_columns=None):
    """Plot density estimates of cases vs. controls projected onto the phenotype vector.

    The phenotype vector is the unit vector pointing from the centroid of the
    phecode-negative samples to the centroid of the phecode-positive samples in
    latent space. Each sample's latent vector is projected onto it and the two
    resulting distributions are drawn as kernel density estimates.

    Args:
        latent_df: frame with a 'LINKER_ID' column and the latent columns.
        phecode_file: tab-separated file with 'linker_id', 'phenotype' and
            'has_disease' columns. Paths without 'PheCode' in them are ignored
            (the function returns None without plotting).
        phecode_title: human-readable phenotype name used in title, labels and legend.
        out_path: unused; kept so existing calls keep working.
        random_permute: if True, the "positive" group is a random sample of the
            same size and the "negative" group is everything else, which gives a
            random direction to compare against. The sample is not seeded.
        latent_columns: names of the latent columns; defaults to the notebook-level
            `latent_cols`.

    Raises:
        ValueError: if either group is empty after merging (no centroid exists).
    """
    if 'PheCode' not in phecode_file:
        return
    cols = latent_cols if latent_columns is None else list(latent_columns)

    df = pd.read_csv(phecode_file, sep='\t')
    phecode_name = f'phe_{df.iloc[0].phenotype}'.replace('.', '_').replace(' ', '')
    ratio = df.has_disease.sum() / len(df.has_disease)
    print(f'phecode_name  {phecode_name} ratio: {ratio} len phe: {len(df)}')
    print([c for c in df if 'latent' not in c] )   #phecode_name, 1, 0
    df = df.rename(columns={'has_disease': phecode_name})
    latent_df = pd.merge(latent_df, df, left_on='LINKER_ID', right_on='linker_id', how='inner')
    print(f'MERGED phecode_name  {phecode_name}')
    latent_df.info()

    hit = latent_df.loc[latent_df[phecode_name] >= 1]
    miss = latent_df.loc[latent_df[phecode_name] < 1]
    if random_permute:
        hit = latent_df.sample(n=len(hit))
        miss = latent_df.drop(hit.index)
    hit_np = hit[cols].to_numpy()
    miss_np = miss[cols].to_numpy()
    if len(hit_np) == 0 or len(miss_np) == 0:
        raise ValueError(
            f'{phecode_name}: need at least one positive and one negative sample, '
            f'got {len(hit_np)} and {len(miss_np)}')

    # The centroids were previously read from undefined (or stale notebook-level)
    # names; compute them from the groups selected above.
    positroid = np.mean(hit_np, axis=0)
    negatroid = np.mean(miss_np, axis=0)
    phenotype_vector = unit_vector(positroid - negatroid)

    # Scalar projection of every sample onto the phenotype direction.
    hit_dots = np.dot(hit_np, phenotype_vector)
    miss_dots = np.dot(miss_np, phenotype_vector)

    # Welch t-test between the two projections; only needed for the optional
    # P-value annotation of the title, which is currently disabled.
    t_stat, p_val = stats.ttest_ind(hit_dots, miss_dots, equal_var = False)

    fig, ax = plt.subplots(figsize=(8, 4), dpi=300)
    labelled_dists = [
        (f'{phecode_title} Present, n={len(hit_dots)}', hit_dots),
        (f'{phecode_title} Absent, n={len(miss_dots)}', miss_dots),
    ]
    for label, data in labelled_dists:
        # NOTE(review): recent seaborn releases deprecate `bw` in favour of
        # `bw_method` / `bw_adjust`; left unchanged because the installed version
        # is not visible here.
        sb.kdeplot(np.array(data), bw=0.25, label=label, ax=ax)

    if random_permute:
        ax.set_xlabel(f'Position along random vector')
    else:
        ax.set_xlabel(f'Position along {phecode_title} vector')
    ax.set_title(f'{phecode_title}')  # optionally append: T-Test P-Value {p_val:0.2E}
    ax.set_ylabel('Normalized density of distribution')
    ax.legend()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.show()

def write_phecode_centroid(latent_df, phecode_file, out_path, latent_columns=None):
    """Rank samples by latent-space distance to the phecode-positive centroid.

    The latent space of the merged samples is standardised with its global mean
    and standard deviation (one scalar each, over all samples and dimensions).
    The centroid of the phecode-positive samples ("positroid") is computed in that
    standardised space, and every sample gets its Euclidean distance to it plus a
    percentile rank (0 = closest). Non-latent columns are written to
    '<out_path>.csv'.

    Args:
        latent_df: frame with a 'LINKER_ID' column and the latent columns; it is
            not modified.
        phecode_file: tab-separated file with 'linker_id', 'phenotype' and
            'has_disease' columns. Paths without 'PheCode' in them are ignored
            (the function returns None and writes nothing).
        out_path: output path without extension; '.csv' is appended.
        latent_columns: names of the latent columns; defaults to the notebook-level
            `latent_cols`.

    Returns:
        The merged frame sorted by ascending distance, with the added columns
        '<phecode_name>_positroid_distance' and '<phecode_name>_positroid_percentile'.
    """
    if 'PheCode' not in phecode_file:
        return
    cols = latent_cols if latent_columns is None else list(latent_columns)

    df = pd.read_csv(phecode_file, sep='\t')
    phecode_name = f'phe_{df.iloc[0].phenotype}'.replace('.', '_').replace(' ', '')
    ratio = df.has_disease.sum() / len(df.has_disease)
    print(f'phecode_name  {phecode_name} ratio: {ratio} len phe: {len(df)}')
    df = df.rename(columns={'has_disease': phecode_name})
    merged = pd.merge(latent_df, df, left_on='LINKER_ID', right_on='linker_id', how='inner')

    is_hit = (merged[phecode_name] >= 1).to_numpy()
    is_miss = (merged[phecode_name] < 1).to_numpy()

    # Standardise ONCE with statistics of the raw space. Previously `space` was
    # normalised in place first and its (then ~0 / ~1) mean and std were reused for
    # the hit/miss arrays, leaving the centroids effectively unnormalised while the
    # distances were measured from normalised points.
    raw_space = merged[cols].to_numpy(dtype=float)
    space = (raw_space - np.mean(raw_space)) / np.std(raw_space)
    hit_np = space[is_hit]
    miss_np = space[is_miss]
    print(f'len(hit): {int(is_hit.sum())} len(miss): {int(is_miss.sum())}')

    positroid = np.mean(hit_np, axis=0)
    negatroid = np.mean(miss_np, axis=0)
    print(f'pearsonr: {pearsonr(positroid, negatroid)}')

    distance_col = f'{phecode_name}_positroid_distance'
    percentile_col = f'{phecode_name}_positroid_percentile'
    merged[distance_col] = np.sqrt(np.sum((space - positroid) ** 2, axis=1))
    merged = merged.sort_values(by=distance_col)
    # Rows are already sorted, so the rank of each row is simply its position.
    merged[percentile_col] = np.arange(len(merged), dtype=np.float32) / len(merged) * 100

    # NOTE(review): this correlates the distance column with itself (always 1.0);
    # it looks like a leftover from a comparison against a negatroid distance that
    # is no longer computed.
    print(f"Distance pearsonr: {pearsonr(merged[distance_col], merged[distance_col])}")
    merged[[c for c in merged if 'latent' not in c]].to_csv(f'{out_path}.csv', index=False)
    return merged
    

In [None]:
# Centroid distances for PheCode_806; the returned frame is kept in `df`.
df = write_phecode_centroid(
    latent_df,
    phecode_file='../../phecodes/PheCode_806.txt',
    out_path='../../phecode_centroids/lbbb_phecode_426_32_new2',
)

In [None]:
import seaborn as sb
# Histogram for PheCode_388, titled 'Obesity'.
histogram_phecode(
    latent_df,
    phecode_file='../../phecodes/PheCode_388.txt',
    phecode_title='Obesity',
    out_path='../../mgh_Obesity_phecode',
)

In [None]:
# Histogram for PheCode_806, titled 'Left Bundle Branch Block'.
histogram_phecode(
    latent_df,
    phecode_file='../../phecodes/PheCode_806.txt',
    phecode_title='Left Bundle Branch Block',
    out_path='../../mgh_lbbb_phecode',
)

In [None]:
# Write the merged PheCode_388 table to '../../mgh_obesity_phecode.csv'.
write_phecode(
    latent_df,
    phecode_file='../../phecodes/PheCode_388.txt',
    out_path='../../mgh_obesity_phecode',
)

In [None]:
# Build phe2file: phecode string (e.g. '426_32') -> path of the matching PheCode file.
# The phecode is read from the 'phenotype' value in the first row of each file.
# NOTE(review): files are listed and read from '../../phecodes_bwh/' but the stored
# path points into '../../phecodes/'. This only works if both folders use the same
# file names for the same phecodes -- confirm this is intended.
phe2file = {}
test_phe_folder = '../../phecodes_bwh/'
for phe_file in sorted(os.listdir(test_phe_folder)):
    # Skip stray directory entries; the phecode helpers ignore such paths as well.
    if 'PheCode' not in phe_file:
        continue
    # A dedicated name is used so the notebook-level `df` is not overwritten.
    phe_table = pd.read_csv(os.path.join(test_phe_folder, phe_file), sep='\t')
    phecode_name = f'{phe_table.iloc[0].phenotype}'.replace('.', '_').replace(' ', '')
    phe2file[phecode_name] = f'../../phecodes/{phe_file}'
    print(f'phe_file  {phe_file}   {phecode_name}' )

In [None]:
# Phecodes (keys of phe2file) for which centroid files are written.
# Duplicates of '496_3', '585' and '600' were removed: they only made the loop over
# this list recompute and overwrite the same output file. Order of first occurrence
# is preserved.
phe_list = [
    '401', '426_32', '571', '572', '394_1',
    '425', '428_1', '411_2', '411_4', '427',
    '427_2', '427_12', '428_4', '425_11', '496',
    '509', '495', '442_1', '502', '496_3',
    '476', '512', '585', '580_1', '600',
    '250', '272_1', '274_1', '574', '562',
    '327_3', '345', '331_9', '316', '285',
    '286',
]

In [None]:
# NOTE(review): `code` is not assigned anywhere above this cell; it only exists
# after one of the `for code in ...` loops in later cells has run, so this cell
# raises NameError on a fresh kernel.
write_phecode_centroid(
    latent_df,
    phecode_file=phe2file[code],
    out_path=f'../../phecode_centroids_new/mgh_phecode_{code}',
)

In [None]:
# Write one centroid CSV per phecode in phe_list.
for code in phe_list:
    write_phecode_centroid(
        latent_df,
        phecode_file=phe2file[code],
        out_path=f'../../phecode_centroids_v2022_11_11/mgh_phecode_{code}',
    )

In [None]:
# Write one centroid CSV for every phecode known to phe2file.
for code in phe2file:
    write_phecode_centroid(
        latent_df,
        phecode_file=phe2file[code],
        out_path=f'../../phecode_centroids/mgh_phecode_{code}',
    )

In [None]:
# Centroid distances for PheCode_806; the returned frame is kept in `df`.
# NOTE(review): an earlier cell makes the same call with the same output path.
df = write_phecode_centroid(
    latent_df,
    phecode_file='../../phecodes/PheCode_806.txt',
    out_path='../../phecode_centroids/lbbb_phecode_426_32_new2',
)

In [None]:

#pearsonr(df.phe_426_32_positroid_distance, df.phe_426_32_negatroid_distance)

In [None]:
# Write the merged PheCode_388 table to '../../mgh_obesity_phecode_278.1.csv'.
write_phecode(
    latent_df,
    phecode_file='../../phecodes/PheCode_388.txt',
    out_path='../../mgh_obesity_phecode_278.1',
)

In [None]:
# Write the merged PheCode_830 table to '../../mgh_congestive_hf_phecode_428.csv'.
write_phecode(
    latent_df,
    phecode_file='../../phecodes/PheCode_830.txt',
    out_path='../../mgh_congestive_hf_phecode_428',
)

In [None]:
# Write the merged PheCode_806 table to '../../mgh_lbbb_phecode_426.32.csv'.
write_phecode(
    latent_df,
    phecode_file='../../phecodes/PheCode_806.txt',
    out_path='../../mgh_lbbb_phecode_426.32',
)

In [None]:
# Write the merged PheCode_760 table to '../../mgh_hypertension_phecode_401.csv'.
write_phecode(
    latent_df,
    phecode_file='../../phecodes/PheCode_760.txt',
    out_path='../../mgh_hypertension_phecode_401',
)