In [3]:
from ptbxlae.dataprocessing.dataModules import SingleCycleCachedDM
from ptbxlae.modeling.singleCycleConv import SingleCycleConvVAE

import pandas as pd
from tqdm.auto import tqdm
import os
import torch
import numpy as np

pd.options.mode.chained_assignment = None  # default='warn'

# Get Latent Representations

In [4]:
dm = SingleCycleCachedDM(cache_folder="../cache/singlecycle_data")
dm.setup(stage="test")
metadata = dm.test_ds.dataset.metadata

torch.no_grad()
m = SingleCycleConvVAE.load_from_checkpoint('../cache/savedmodels/last-v9.ckpt').eval()
m.cpu()

latent_dicts = list()

collected_latents = list()

for test_index in tqdm(dm.test_ds.indices):
    pid = dm.test_ds.dataset.patient_ids[test_index]
    patient_dir = f"../cache/singlecycle_data/{pid}"

    # For purposes of testing, only consider first ecg in patient directory
    ecg_id = os.listdir(patient_dir)[0]
    ecg_dir = f"{patient_dir}/{ecg_id}"
    cycles = os.listdir(ecg_dir)

    batched_cycles = np.stack([pd.read_parquet(f"{ecg_dir}/{c}").to_numpy().transpose() for c in cycles])

    latent_representations = m.encode(torch.Tensor(batched_cycles)).mean(dim=0).detach()
    
    labeled_series = pd.Series(data=latent_representations, index=[f'latent_{x}' for x in range(0, m.latent_dim)])
    labeled_series['ecg_id'] = int(ecg_id)
    # labeled_series['patient_id'] = int(pid)
    
    collected_latents.append(labeled_series)


latent_df = pd.concat(collected_latents, axis=1).T.set_index('ecg_id')
latent_df

  0%|          | 0/1886 [00:00<?, ?it/s]

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,latent_30,latent_31,latent_32,latent_33,latent_34,latent_35,latent_36,latent_37,latent_38,latent_39
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196.0,-0.403326,0.250151,-0.356406,-0.546642,-0.005614,-0.048089,0.493035,0.279139,0.194466,0.054529,...,0.244233,-0.013454,0.688355,-0.364839,0.454960,-0.539895,-0.118311,-0.257853,0.447013,0.689100
3668.0,-0.292931,-0.037926,0.527894,-0.197620,-0.242009,-0.182316,-0.166472,-0.056521,-0.157053,-0.120868,...,0.336133,0.338843,-1.050666,-0.260718,0.659946,0.067202,0.264377,0.275234,-0.144624,-0.303861
14208.0,0.391137,-0.257465,0.030430,0.315410,-0.080351,-0.344115,-0.034820,-0.060115,-0.129054,-0.223707,...,-0.101984,-0.064872,-0.617020,0.119079,-0.118057,-0.001525,0.188477,0.141371,1.878248,-0.683605
11998.0,-0.260490,-0.211127,0.322183,0.015873,-0.519412,0.070331,1.285933,-0.195377,0.417739,-1.078158,...,-0.219120,-0.216730,-0.213710,-0.040173,-0.194736,-0.178203,0.090688,0.381709,-0.242529,0.052120
20383.0,0.151117,0.637537,-0.678529,0.291237,0.234871,-0.204321,0.266440,-0.074526,0.035427,0.248537,...,0.242050,0.352915,0.105273,-0.133592,-0.195522,-0.366279,-0.252623,-0.668661,-0.080612,0.171629
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243.0,-0.066615,-0.165400,0.607334,0.430423,0.123390,0.277408,-0.039967,0.267652,0.012030,-0.525706,...,0.266091,0.242707,-0.924728,0.178040,0.309326,-0.091778,0.696544,-0.572375,-0.682747,-1.087925
18045.0,-0.728071,0.798129,-0.317219,0.314455,-0.588004,0.517043,0.366484,-0.400036,-0.048394,-0.723383,...,0.316432,-0.753666,-0.584291,0.066964,0.271190,0.660915,0.520717,0.523755,-0.476911,-1.274686
16081.0,0.163421,-0.383719,0.232284,0.143286,0.336172,-0.056222,0.218722,-0.169651,-0.017852,0.382086,...,0.873949,-0.260048,0.678603,-0.034526,0.145495,-0.356723,0.549067,0.236102,-0.190113,-0.753694
15854.0,-0.340427,-0.473701,0.562393,-0.147118,0.066251,0.426172,-0.025563,-0.928504,-0.300426,-0.310384,...,0.191779,-0.941156,0.532365,-0.103848,0.032864,0.180290,-0.304459,-0.111411,-1.348850,-0.222610


# Assess Predictive Power of Latent Representations for Each Diagnostic Label in PTB

In [5]:
metadata = dm.test_ds.dataset.metadata
combined_df = pd.merge(latent_df, metadata, how='left', left_index=True, right_index=True)
combined_df

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196.0,-0.403326,0.250151,-0.356406,-0.546642,-0.005614,-0.048089,0.493035,0.279139,0.194466,0.054529,...,True,,,,,,,9,records100/19000/19196_lr,records500/19000/19196_hr
3668.0,-0.292931,-0.037926,0.527894,-0.197620,-0.242009,-0.182316,-0.166472,-0.056521,-0.157053,-0.120868,...,True,,,,,,,8,records100/03000/03668_lr,records500/03000/03668_hr
14208.0,0.391137,-0.257465,0.030430,0.315410,-0.080351,-0.344115,-0.034820,-0.060115,-0.129054,-0.223707,...,True,,,,,,,7,records100/14000/14208_lr,records500/14000/14208_hr
11998.0,-0.260490,-0.211127,0.322183,0.015873,-0.519412,0.070331,1.285933,-0.195377,0.417739,-1.078158,...,True,,,,,,,1,records100/11000/11998_lr,records500/11000/11998_hr
20383.0,0.151117,0.637537,-0.678529,0.291237,0.234871,-0.204321,0.266440,-0.074526,0.035427,0.248537,...,True,,,,,,,5,records100/20000/20383_lr,records500/20000/20383_hr
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243.0,-0.066615,-0.165400,0.607334,0.430423,0.123390,0.277408,-0.039967,0.267652,0.012030,-0.525706,...,True,,,,,,,4,records100/10000/10243_lr,records500/10000/10243_hr
18045.0,-0.728071,0.798129,-0.317219,0.314455,-0.588004,0.517043,0.366484,-0.400036,-0.048394,-0.723383,...,False,,,,,,,6,records100/18000/18045_lr,records500/18000/18045_hr
16081.0,0.163421,-0.383719,0.232284,0.143286,0.336172,-0.056222,0.218722,-0.169651,-0.017852,0.382086,...,True,,,,,,,6,records100/16000/16081_lr,records500/16000/16081_hr
15854.0,-0.340427,-0.473701,0.562393,-0.147118,0.066251,0.426172,-0.025563,-0.928504,-0.300426,-0.310384,...,True,,,,,,,3,records100/15000/15854_lr,records500/15000/15854_hr


In [6]:
import ast
from util import eval_predictive_power_binary_outcome

all_scps = pd.read_csv("../data/scp_statements.csv", index_col=0)

def ptb_val_to_pseudobinary_label(scp_code_of_interest: str, this_recording_scp_codes: str):
    if scp_code_of_interest not in this_recording_scp_codes.keys():
        return 0.0
    elif scp_code_of_interest in this_recording_scp_codes.keys() and this_recording_scp_codes[scp_code_of_interest] == 100.0:
        return 1.0
    else:
        return float('nan')


results = list()
for scp_code in tqdm(all_scps.index.to_list()):
    combined_df[f'scp.{scp_code}'] = combined_df['scp_codes'].apply(lambda codes: ptb_val_to_pseudobinary_label(scp_code, ast.literal_eval(codes)))
    relevant_df = combined_df[~combined_df[f'scp.{scp_code}'].isna()]

    res = eval_predictive_power_binary_outcome(relevant_df[latent_df.columns], relevant_df[f'scp.{scp_code}'])
    res['Target'] = all_scps.loc[scp_code]['description']
    results.append(res)


results_df = pd.DataFrame.from_records(results)
results_df.nlargest(n=50, columns=['Avg CV score'])


  0%|          | 0/71 [00:00<?, ?it/s]

Unnamed: 0,Total usable,% positive,Avg CV score,Target
15,1886,0.027041,0.996884,complete left bundle branch block
14,1886,0.020148,0.981719,complete right bundle branch block
8,1884,0.064756,0.962992,left anterior fascicular block
7,1810,0.049724,0.951034,left ventricular hypertrophy
9,1874,0.040555,0.93274,non-specific ischemic
21,1886,0.009544,0.925821,subendocardial injury in anteroseptal leads
19,1877,0.008524,0.92168,anterolateral myocardial infarction
25,1883,0.005311,0.912423,ischemic in anteroseptal leads
4,1674,0.419355,0.905141,normal ECG
6,1821,0.065349,0.900613,anteroseptal myocardial infarction
