In [1]:
from ptbxlae.dataprocessing.dataModules import SingleCycleCachedDM
from ptbxlae.modeling.convolutionalVAE import ConvolutionalEcgVAE

import pandas as pd
from tqdm.auto import tqdm
import os
import torch
import numpy as np

# Turn off pandas' chained-assignment check for the whole notebook
# (the pandas default for this option is 'warn').
pd.set_option("mode.chained_assignment", None)

*PTB-XL Autoencoder*


# Get Latent Representations

In [2]:
# Build the cached single-cycle data module and prepare only the test stage.
dm = SingleCycleCachedDM(cache_folder="../cache/singlecycle_data")
dm.setup(stage="test")
metadata = dm.test_ds.dataset.metadata

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# map_location="cpu" lets the checkpoint deserialize on machines that lack the
# device it was saved from; without it the load can fail before `m.cpu()` runs.
# NOTE(review): assumes ConvolutionalEcgVAE uses the standard Lightning
# `load_from_checkpoint` signature -- confirm it is not overridden.
m = ConvolutionalEcgVAE.load_from_checkpoint(
    '../cache/archivedmodels/scc-epoch=081-val_loss=332.327423.ckpt',
    map_location="cpu",
).eval()
m.cpu()
m

ConvolutionalEcgVAE(
  (encoder): ConvolutionalEcgEncoder(
    (net): Sequential(
      (0): Conv1d(12, 24, kernel_size=(13,), stride=(2,), padding=(6,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(24, 48, kernel_size=(13,), stride=(2,), padding=(6,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Flatten(start_dim=1, end_dim=-1)
      (5): Linear(in_features=6000, out_features=1500, bias=True)
      (6): LeakyReLU(negative_slope=0.01)
    )
  )
  (decoder): ConvolutionalEcgDecoder(
    (net): Sequential(
      (0): Linear(in_features=40, out_features=1500, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=1500, out_features=6000, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Unflatten(dim=1, unflattened_size=(48, 125))
      (5): ConvTranspose1d(48, 24, kernel_size=(13,), stride=(2,), padding=(6,), output_padding=(1,))
      (6): LeakyReLU(negative_slope=0.01)
      (7): ConvTranspose1d(24, 12, kernel_size=(13,),

In [3]:

cache_root = "../cache/singlecycle_data"
latent_columns = [f'latent_{x}' for x in range(0, m.encoder.architecture_params.latent_dim)]

ecg_ids = list()
collected_latents = list()

for test_index in tqdm(dm.test_ds.indices):
    pid = dm.test_ds.dataset.patient_ids[test_index]
    patient_dir = f"{cache_root}/{pid}"

    # For purposes of testing, only consider one ecg per patient directory.
    # os.listdir() returns entries in arbitrary order, so sort the listing to
    # select the same recording (and the same cycle order) on every run.
    ecg_id = sorted(os.listdir(patient_dir))[0]
    ecg_dir = f"{patient_dir}/{ecg_id}"
    cycles = sorted(os.listdir(ecg_dir))

    # One parquet file per cycle; each is transposed before being stacked
    # into a single batch along a new leading axis.
    batched_cycles = np.stack([pd.read_parquet(f"{ecg_dir}/{c}").to_numpy().transpose() for c in cycles])

    # Encode all cycles of the recording and average over the batch axis to
    # get one latent vector per recording.
    latent_representation = m.encode(torch.Tensor(batched_cycles)).mean(dim=0)

    ecg_ids.append(int(ecg_id))
    # Convert explicitly to numpy rather than relying on pandas to coerce a tensor.
    collected_latents.append(latent_representation.detach().cpu().numpy())

# Build the frame once: one row per recording, indexed by integer ecg_id,
# with the latent values stored as float64.
latent_df = pd.DataFrame(
    data=np.stack(collected_latents).astype("float64"),
    index=pd.Index(ecg_ids, name='ecg_id'),
    columns=latent_columns,
)
latent_df

  0%|          | 0/1886 [00:00<?, ?it/s]

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,latent_30,latent_31,latent_32,latent_33,latent_34,latent_35,latent_36,latent_37,latent_38,latent_39
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196,0.292990,0.541875,-0.237049,-0.032857,-0.351980,0.733638,-0.313003,0.128665,-0.099631,-0.234924,...,0.895284,0.809396,-0.318718,0.608100,0.056431,-0.199841,0.083749,0.023060,0.090378,0.222820
3668,-0.165551,-0.081546,-0.193163,0.042614,0.654696,0.255240,0.468165,-0.555480,0.244313,-0.285057,...,0.137633,-0.972584,0.181875,-0.370064,-0.065797,0.079240,0.268899,-0.290898,0.152508,-0.496413
14208,-0.416336,-0.250682,0.951321,0.190291,-0.489754,0.229042,-0.292869,-0.292163,0.298534,-0.071221,...,0.035932,0.218919,-0.041759,0.018164,-0.008204,-0.072490,0.122405,-0.582753,0.086064,0.044200
11998,-0.370392,0.030018,0.046648,0.480106,0.005712,-0.193515,0.078482,-0.000124,0.305600,-0.327691,...,-0.400573,0.038958,-0.150129,0.294635,-0.226235,0.292769,-0.217610,-0.490261,0.174991,-0.163576
20383,0.064888,-0.436333,0.576494,0.111732,-0.293115,-0.911887,0.158799,-0.266280,-0.179614,-0.214653,...,-0.489319,0.069782,-0.101244,0.543064,-0.127098,0.025844,-0.455001,0.054589,-0.136881,-0.081879
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243,-1.394182,-0.590800,0.530042,0.161419,0.388833,-0.198705,-0.410419,0.362616,-0.022701,0.137373,...,-0.394444,-1.295954,-0.242687,0.699665,0.011117,0.762537,-0.145483,0.830036,0.538386,-0.418469
18045,-0.674545,0.290137,0.138001,0.173740,0.396240,-0.192699,0.182556,0.486441,0.462152,0.237036,...,0.238970,-1.513354,-0.016000,-0.091952,0.008034,0.374174,-0.124548,0.187693,-0.033377,0.265839
16081,-0.439224,-0.359716,0.146834,0.494796,0.214289,0.050397,0.054304,0.015023,0.250305,0.342573,...,0.159463,-0.395124,0.332184,0.288136,0.079799,0.521756,-0.392505,0.228625,-0.270765,0.051066
15854,-0.242156,-0.042555,0.219098,-0.305329,0.262643,0.511476,0.092298,0.642343,0.551860,-0.221172,...,0.164159,-0.220878,-0.880030,0.629271,-0.218347,-0.036274,0.202233,0.087937,0.578818,-0.204798


# Assess Predictive Power of Latent Representations for Each Diagnostic Label in PTB-XL

In [4]:
# Left join on the index of both frames: every row of latent_df is kept and
# gains the columns of the dataset metadata.
metadata = dm.test_ds.dataset.metadata
combined_df = latent_df.merge(
    metadata,
    how='left',
    left_index=True,
    right_index=True,
)

# Write the merged table to the cache folder, then display it.
combined_df.to_parquet("../cache/eval_pipeline.parquet")
combined_df

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196,0.292990,0.541875,-0.237049,-0.032857,-0.351980,0.733638,-0.313003,0.128665,-0.099631,-0.234924,...,True,,,,,,,9,records100/19000/19196_lr,records500/19000/19196_hr
3668,-0.165551,-0.081546,-0.193163,0.042614,0.654696,0.255240,0.468165,-0.555480,0.244313,-0.285057,...,True,,,,,,,8,records100/03000/03668_lr,records500/03000/03668_hr
14208,-0.416336,-0.250682,0.951321,0.190291,-0.489754,0.229042,-0.292869,-0.292163,0.298534,-0.071221,...,True,,,,,,,7,records100/14000/14208_lr,records500/14000/14208_hr
11998,-0.370392,0.030018,0.046648,0.480106,0.005712,-0.193515,0.078482,-0.000124,0.305600,-0.327691,...,True,,,,,,,1,records100/11000/11998_lr,records500/11000/11998_hr
20383,0.064888,-0.436333,0.576494,0.111732,-0.293115,-0.911887,0.158799,-0.266280,-0.179614,-0.214653,...,True,,,,,,,5,records100/20000/20383_lr,records500/20000/20383_hr
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243,-1.394182,-0.590800,0.530042,0.161419,0.388833,-0.198705,-0.410419,0.362616,-0.022701,0.137373,...,True,,,,,,,4,records100/10000/10243_lr,records500/10000/10243_hr
18045,-0.674545,0.290137,0.138001,0.173740,0.396240,-0.192699,0.182556,0.486441,0.462152,0.237036,...,False,,,,,,,6,records100/18000/18045_lr,records500/18000/18045_hr
16081,-0.439224,-0.359716,0.146834,0.494796,0.214289,0.050397,0.054304,0.015023,0.250305,0.342573,...,True,,,,,,,6,records100/16000/16081_lr,records500/16000/16081_hr
15854,-0.242156,-0.042555,0.219098,-0.305329,0.262643,0.511476,0.092298,0.642343,0.551860,-0.221172,...,True,,,,,,,3,records100/15000/15854_lr,records500/15000/15854_hr


In [5]:
import ast
from util import eval_predictive_power_binary_outcome

# SCP statement table; the first CSV column becomes the index.
scp_statements_path = "../data/scp_statements.csv"
all_scps = pd.read_csv(scp_statements_path, index_col=0)

def ptb_val_to_pseudobinary_label(scp_code_of_interest: str, this_recording_scp_codes: dict) -> float:
    """Turn one recording's SCP annotation into a binary label for a single code.

    Args:
        scp_code_of_interest: The SCP code to produce a label for.
        this_recording_scp_codes: Mapping from SCP code to its value for this
            recording (the already-parsed mapping, not its string form).

    Returns:
        0.0 if the code is absent from the mapping, 1.0 if it is present with
        a value equal to 100.0, and NaN for any other value so the caller can
        filter the recording out.
    """
    if scp_code_of_interest not in this_recording_scp_codes:
        return 0.0
    # From here on the code is known to be present; only its value matters.
    if this_recording_scp_codes[scp_code_of_interest] == 100.0:
        return 1.0
    return float('nan')


# Parse the stringified SCP-code mappings once up front instead of re-parsing
# every row for every SCP code inside the loop.
parsed_scp_codes = combined_df['scp_codes'].apply(ast.literal_eval)

results = list()
for scp_code in tqdm(all_scps.index.to_list()):
    label_column = f'scp.{scp_code}'
    combined_df[label_column] = parsed_scp_codes.apply(lambda codes: ptb_val_to_pseudobinary_label(scp_code, codes))

    # Rows labelled NaN are neither 0.0 nor 1.0 for this code; leave them out.
    relevant_df = combined_df[combined_df[label_column].notna()]

    res = eval_predictive_power_binary_outcome(relevant_df[latent_df.columns], relevant_df[label_column])
    # Single .loc lookup (row, column) instead of chained indexing.
    res['Target'] = all_scps.loc[scp_code, 'description']
    results.append(res)


results_df = pd.DataFrame.from_records(results)
results_df.nlargest(n=50, columns=['Avg CV score'])


  0%|          | 0/71 [00:00<?, ?it/s]

Unnamed: 0,Total usable,% positive,Avg CV score,Target
15,1886,0.027041,0.994783,complete left bundle branch block
14,1886,0.020148,0.988408,complete right bundle branch block
8,1884,0.064756,0.96829,left anterior fascicular block
7,1810,0.049724,0.942571,left ventricular hypertrophy
19,1877,0.008524,0.942549,anterolateral myocardial infarction
21,1886,0.009544,0.917884,subendocardial injury in anteroseptal leads
9,1874,0.040555,0.912671,non-specific ischemic
6,1821,0.065349,0.91156,anteroseptal myocardial infarction
16,1869,0.010166,0.910541,inferolateral myocardial infarction
4,1674,0.419355,0.899769,normal ECG
