In [1]:
from ptbxlae.dataprocessing.dataModules import SingleCycleCachedDM
from ptbxlae.modeling.convolutionalVAE import ConvolutionalEcgVAE

import pandas as pd
from tqdm.auto import tqdm
import os
import torch
import numpy as np

pd.options.mode.chained_assignment = None  # default='warn'

# PTB-XL Autoencoder


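## Get Latent Representations

For each test-set patient, one ECG is encoded cycle by cycle, and the per-cycle latent vectors are averaged into a single representation for that ECG.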

In [2]:
# Test split of the cached single-cycle dataset and the metadata table attached to it.
dm = SingleCycleCachedDM(cache_folder="../cache/singlecycle_data")
dm.setup(stage="test")
metadata = dm.test_ds.dataset.metadata

# Inference only from here on: autograd is disabled globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Archived VAE checkpoint; put into eval mode and moved to CPU before use.
checkpoint_path = '../cache/archivedmodels/scc-epoch=081-val_loss=332.327423.ckpt'
m = ConvolutionalEcgVAE.load_from_checkpoint(checkpoint_path)
m.eval()
m.cpu()
m

ConvolutionalEcgVAE(
  (encoder): ConvolutionalEcgEncoder(
    (net): Sequential(
      (0): Conv1d(12, 24, kernel_size=(13,), stride=(2,), padding=(6,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(24, 48, kernel_size=(13,), stride=(2,), padding=(6,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Flatten(start_dim=1, end_dim=-1)
      (5): Linear(in_features=6000, out_features=1500, bias=True)
      (6): LeakyReLU(negative_slope=0.01)
    )
  )
  (decoder): ConvolutionalEcgDecoder(
    (net): Sequential(
      (0): Linear(in_features=40, out_features=1500, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=1500, out_features=6000, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Unflatten(dim=1, unflattened_size=(48, 125))
      (5): ConvTranspose1d(48, 24, kernel_size=(13,), stride=(2,), padding=(6,), output_padding=(1,))
      (6): LeakyReLU(negative_slope=0.01)
      (7): ConvTranspose1d(24, 12, kernel_size=(13,),

In [3]:

# For each test-set patient, encode the cached single cycles of one ECG and average the
# per-cycle latent vectors into one row of `latent_df` (indexed by ecg_id).
# Only one ECG per patient is used for now. It is chosen as the lowest numeric ecg_id,
# because os.listdir() order is arbitrary and would make the choice non-reproducible.
singlecycle_root = "../cache/singlecycle_data"
latent_columns = [f'latent_{i}' for i in range(m.encoder.architecture_params.latent_dim)]


def _load_first_ecg_cycles(patient_dir: str) -> "tuple[int, np.ndarray]":
    """Load every cached cycle of the lowest-numbered ECG directory under ``patient_dir``.

    Args:
        patient_dir: directory whose numeric sub-directory names are ecg_ids.

    Returns:
        ``(ecg_id, cycles)``. ``cycles`` stacks one transposed array per cycle parquet
        file (sorted by file name), giving ``(n_cycles, channels, samples)``. The encoder's
        first Conv1d expects 12 channels.

    Raises:
        FileNotFoundError: if there is no numeric ECG directory, or the ECG directory is empty.
    """
    ecg_dirs = sorted((d for d in os.listdir(patient_dir) if d.isdigit()), key=int)
    if not ecg_dirs:
        raise FileNotFoundError(f"no ECG directories under {patient_dir}")
    ecg_id = ecg_dirs[0]
    ecg_dir = os.path.join(patient_dir, ecg_id)

    cycle_files = sorted(os.listdir(ecg_dir))
    if not cycle_files:
        raise FileNotFoundError(f"no cached cycles under {ecg_dir}")
    cycles = np.stack(
        [pd.read_parquet(os.path.join(ecg_dir, c)).to_numpy().transpose() for c in cycle_files]
    )
    return int(ecg_id), cycles


mean_latents = []
encoded_ecg_ids = []
for test_index in tqdm(dm.test_ds.indices, desc="encoding test ECGs"):
    pid = dm.test_ds.dataset.patient_ids[test_index]
    ecg_id, cycles = _load_first_ecg_cycles(f"{singlecycle_root}/{pid}")

    # Encode all cycles as one batch, then average over the cycle axis.
    per_cycle_latents = m.encode(torch.as_tensor(cycles, dtype=torch.float32))
    mean_latents.append(per_cycle_latents.mean(dim=0).detach().cpu().numpy())
    encoded_ecg_ids.append(ecg_id)

# Build the frame once; the DataFrame constructor raises if a vector's length != latent_dim.
latent_df = pd.DataFrame(
    np.stack(mean_latents).astype(np.float64),
    index=pd.Index(encoded_ecg_ids, name='ecg_id'),
    columns=latent_columns,
)
assert latent_df.index.is_unique, "each ecg_id should be encoded at most once"
latent_df

  0%|          | 0/1886 [00:00<?, ?it/s]

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,latent_30,latent_31,latent_32,latent_33,latent_34,latent_35,latent_36,latent_37,latent_38,latent_39
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196,-0.303843,0.360902,-0.228601,0.335436,-0.271603,-0.093063,-0.340502,0.174494,-0.460498,-0.284710,...,-0.169714,0.668969,0.089786,0.467682,-0.025359,-0.553428,-0.103301,-0.022508,0.560380,0.020926
3668,-0.273493,0.290129,-0.691656,-0.049703,-0.177563,0.215130,-0.264510,0.139445,0.005577,0.629999,...,-0.221462,-0.727805,-0.444252,-0.326195,0.228936,-0.026667,-0.015919,-0.213371,-0.007092,0.093512
14208,0.218258,-0.081778,0.528267,-0.077789,-0.442081,-0.000252,0.624470,0.127420,-0.055355,-0.250661,...,-0.109791,0.209794,-0.281170,0.417842,0.172017,-0.045404,-0.208351,-0.274905,-0.105042,0.199410
11998,0.167583,-0.053707,0.181361,0.105666,0.006340,-0.956509,-0.235292,-0.107994,0.182667,-0.821843,...,-0.312430,-0.095533,-0.059998,-0.379110,-0.010326,0.435635,0.945761,0.497419,-0.112318,0.185189
20383,0.041010,-0.544289,-0.131169,-0.000858,-0.416247,0.216033,0.439566,-0.140734,0.240704,-0.148970,...,-0.089692,-0.864767,0.167701,0.050081,0.055693,-0.014531,-0.179791,0.033570,0.312077,-0.349808
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243,-1.391763,-0.603890,0.254692,0.121041,0.112623,-0.047975,0.194222,-0.146493,-0.223625,-0.204178,...,-0.383662,-1.224733,-0.158652,0.265646,-0.703461,0.692361,-0.245070,0.838875,-0.169564,-0.331857
18045,-0.498238,-0.261414,0.157817,0.331121,0.256370,-0.540003,0.125904,-0.071753,1.491374,-0.258937,...,-0.153775,-1.255663,0.331679,0.474884,-0.234478,0.353700,0.450454,0.337069,0.026088,0.445704
16081,0.290756,-0.431782,0.446510,0.347812,0.150811,0.238983,0.106061,-0.257096,0.276724,0.169287,...,-0.205710,-0.093902,0.092576,0.235409,0.000508,0.560363,0.071305,0.390164,0.077860,-0.347680
15854,-0.527246,0.054852,0.100475,0.506417,0.200829,0.545346,-0.345091,-0.611054,-0.205950,-0.043287,...,-0.607084,0.374571,-0.291290,-0.662856,0.265012,0.732282,0.415672,0.382822,0.178055,-0.124122


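## Assess the Predictive Power of Latent Representations for Each SCP Statement in PTB-XL

Each statement is handled as follows:
- A recording is labeled 1 if the statement is present with a value of exactly 100, and 0 if the statement is absent.
- Recordings where the statement is present with any other value are treated as uncertain and excluded for that statement.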

In [4]:
# Attach the metadata table (loaded with the data module earlier) to each latent row by ecg_id.
# Fail fast if any encoded ECG has no metadata row. A left merge would otherwise produce NaN
# rows that only surface later as an obscure ast.literal_eval error on `scp_codes`.
missing_metadata = latent_df.index.difference(metadata.index)
assert missing_metadata.empty, (
    f"{len(missing_metadata)} ecg_ids have no metadata row, e.g. {list(missing_metadata[:5])}"
)
combined_df = pd.merge(
    latent_df, metadata, how='left', left_index=True, right_index=True, validate='one_to_one'
)
combined_df

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196,-0.303843,0.360902,-0.228601,0.335436,-0.271603,-0.093063,-0.340502,0.174494,-0.460498,-0.284710,...,True,,,,,,,9,records100/19000/19196_lr,records500/19000/19196_hr
3668,-0.273493,0.290129,-0.691656,-0.049703,-0.177563,0.215130,-0.264510,0.139445,0.005577,0.629999,...,True,,,,,,,8,records100/03000/03668_lr,records500/03000/03668_hr
14208,0.218258,-0.081778,0.528267,-0.077789,-0.442081,-0.000252,0.624470,0.127420,-0.055355,-0.250661,...,True,,,,,,,7,records100/14000/14208_lr,records500/14000/14208_hr
11998,0.167583,-0.053707,0.181361,0.105666,0.006340,-0.956509,-0.235292,-0.107994,0.182667,-0.821843,...,True,,,,,,,1,records100/11000/11998_lr,records500/11000/11998_hr
20383,0.041010,-0.544289,-0.131169,-0.000858,-0.416247,0.216033,0.439566,-0.140734,0.240704,-0.148970,...,True,,,,,,,5,records100/20000/20383_lr,records500/20000/20383_hr
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243,-1.391763,-0.603890,0.254692,0.121041,0.112623,-0.047975,0.194222,-0.146493,-0.223625,-0.204178,...,True,,,,,,,4,records100/10000/10243_lr,records500/10000/10243_hr
18045,-0.498238,-0.261414,0.157817,0.331121,0.256370,-0.540003,0.125904,-0.071753,1.491374,-0.258937,...,False,,,,,,,6,records100/18000/18045_lr,records500/18000/18045_hr
16081,0.290756,-0.431782,0.446510,0.347812,0.150811,0.238983,0.106061,-0.257096,0.276724,0.169287,...,True,,,,,,,6,records100/16000/16081_lr,records500/16000/16081_hr
15854,-0.527246,0.054852,0.100475,0.506417,0.200829,0.545346,-0.345091,-0.611054,-0.205950,-0.043287,...,True,,,,,,,3,records100/15000/15854_lr,records500/15000/15854_hr


In [5]:
import ast
from util import eval_predictive_power_binary_outcome

# SCP statement table: its index is iterated as the SCP codes to evaluate, and its
# `description` column supplies a human-readable name for each prediction target.
SCP_STATEMENTS_CSV = "../data/scp_statements.csv"
all_scps = pd.read_csv(SCP_STATEMENTS_CSV, index_col=0)

def ptb_val_to_pseudobinary_label(scp_code_of_interest: str, this_recording_scp_codes: "dict[str, float] | str") -> float:
    """Map one recording's SCP codes to a pseudo-binary label for a single statement.

    Args:
        scp_code_of_interest: SCP statement code to label, e.g. ``'NORM'``.
        this_recording_scp_codes: the recording's ``{code: value}`` mapping, or the string
            form of that dict literal (parsed with ``ast.literal_eval``).

    Returns:
        0.0 if the code is absent; 1.0 if it is present with value exactly 100.0.
        Otherwise NaN (present but not certain), so callers can drop ambiguous recordings.
    """
    if isinstance(this_recording_scp_codes, str):
        this_recording_scp_codes = ast.literal_eval(this_recording_scp_codes)

    if scp_code_of_interest not in this_recording_scp_codes:
        return 0.0
    if this_recording_scp_codes[scp_code_of_interest] == 100.0:
        return 1.0
    return float('nan')


# Evaluate how well the latent vectors alone predict each SCP statement. For each statement,
# recordings whose label is NaN (present but uncertain) are excluded before evaluation.
# Each recording's `scp_codes` dict literal is parsed once here, rather than once per statement.
parsed_scp_codes = combined_df['scp_codes'].map(ast.literal_eval)

results = list()
for scp_code in tqdm(all_scps.index.to_list()):
    label_col = f'scp.{scp_code}'
    # apply() runs immediately, so the lambda sees the current scp_code.
    combined_df[label_col] = parsed_scp_codes.apply(
        lambda codes: ptb_val_to_pseudobinary_label(scp_code, codes)
    )
    relevant_df = combined_df[combined_df[label_col].notna()]

    res = eval_predictive_power_binary_outcome(relevant_df[latent_df.columns], relevant_df[label_col])
    res['Target'] = all_scps.loc[scp_code, 'description']
    results.append(res)


results_df = pd.DataFrame.from_records(results)
results_df.nlargest(n=50, columns=['Avg CV score'])


  0%|          | 0/71 [00:00<?, ?it/s]

Unnamed: 0,Total usable,% positive,Avg CV score,Target
15,1886,0.027041,0.993138,complete left bundle branch block
14,1886,0.020148,0.989671,complete right bundle branch block
8,1884,0.064756,0.969234,left anterior fascicular block
19,1877,0.008524,0.965334,anterolateral myocardial infarction
7,1810,0.049724,0.93989,left ventricular hypertrophy
16,1869,0.010166,0.919955,inferolateral myocardial infarction
21,1886,0.009544,0.917776,subendocardial injury in anteroseptal leads
9,1874,0.040555,0.912558,non-specific ischemic
6,1821,0.065349,0.904692,anteroseptal myocardial infarction
4,1674,0.419355,0.899689,normal ECG
