In [1]:
from ptbxlae.dataprocessing.dataModules import SingleCycleCachedDM
from ptbxlae.modeling.convolutionalVAE import ConvolutionalEcgVAE

import pandas as pd
from tqdm.auto import tqdm
import os
import torch
import numpy as np

# Turn off pandas' chained-assignment (SettingWithCopy) warning for the whole
# notebook; later cells add columns to frames derived from other frames.
pd.options.mode.chained_assignment = None  # default='warn'

# PTB-XL Autoencoder


# Get Latent Representations

In [6]:
# Build the cached single-cycle data module and run its setup for the "test" stage.
dm = SingleCycleCachedDM(cache_folder="../cache/singlecycle_data")
dm.setup(stage="test")
metadata = dm.test_ds.dataset.metadata

# This notebook only runs inference, so autograd is disabled globally.
torch.set_grad_enabled(False)

# map_location="cpu" restores the checkpoint directly on the CPU, so a checkpoint
# saved from a GPU run never has to be materialised on a GPU first.
# The trailing .cpu() is kept as a cheap guarantee that the model ends up on CPU.
m = (
    ConvolutionalEcgVAE
    .load_from_checkpoint('../cache/savedmodels/last.ckpt', map_location="cpu")
    .eval()
    .cpu()
)
m

ConvolutionalEcgVAE(
  (encoder): ConvolutionalEcgEncoder(
    (net): Sequential(
      (0): Conv1d(12, 24, kernel_size=(5,), stride=(2,), padding=(2,))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv1d(24, 48, kernel_size=(5,), stride=(2,), padding=(2,))
      (3): LeakyReLU(negative_slope=0.01)
      (4): Flatten(start_dim=1, end_dim=-1)
    )
  )
  (decoder): ConvolutionalEcgDecoder(
    (net): Sequential(
      (0): Linear(in_features=40, out_features=6000, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Unflatten(dim=1, unflattened_size=(48, 125))
      (3): ConvTranspose1d(48, 24, kernel_size=(5,), stride=(2,), padding=(2,), output_padding=(1,))
      (4): LeakyReLU(negative_slope=0.01)
      (5): ConvTranspose1d(24, 12, kernel_size=(5,), stride=(2,), padding=(2,), output_padding=(1,))
      (6): LeakyReLU(negative_slope=0.01)
    )
  )
  (fc_mean): Linear(in_features=6000, out_features=40, bias=True)
  (fc_logvar): Linear(in_features=6000, out_features=

In [7]:

# Root of the on-disk cache, laid out as <root>/<patient_id>/<ecg_id>/<cycle file>.
cache_root = "../cache/singlecycle_data"

# The column labels are the same for every ECG, so build them once up front.
latent_columns = [f'latent_{x}' for x in range(m.encoder.architecture_params.latent_dim)]

collected_latents = []

for test_index in tqdm(dm.test_ds.indices):
    pid = dm.test_ds.dataset.patient_ids[test_index]
    patient_dir = f"{cache_root}/{pid}"

    # For purposes of testing, only consider first ecg in patient directory.
    # os.listdir() returns entries in arbitrary order, so sort (lexicographically)
    # to make "first" reproducible across machines and runs.
    ecg_id = sorted(os.listdir(patient_dir))[0]
    ecg_dir = f"{patient_dir}/{ecg_id}"
    cycles = sorted(os.listdir(ecg_dir))

    # One parquet file per cycle; each is transposed before being stacked into a batch.
    batched_cycles = np.stack([pd.read_parquet(f"{ecg_dir}/{c}").to_numpy().transpose() for c in cycles])

    # Encode every cycle of this ECG in one call, then average over the batch axis
    # to get a single latent vector per ECG.
    latent_representations = m.encode(torch.Tensor(batched_cycles)).mean(dim=0)

    # Hand pandas a plain NumPy array rather than relying on implicit tensor conversion.
    labeled_series = pd.Series(
        data=latent_representations.detach().cpu().numpy(),
        index=latent_columns,
    )
    labeled_series['ecg_id'] = int(ecg_id)

    collected_latents.append(labeled_series)


# One row per ECG, indexed by integer ecg_id.
latent_df = (
    pd.concat(collected_latents, axis=1)
    .T
    .astype({'ecg_id': int})
    .set_index('ecg_id')
)
latent_df

  0%|          | 0/1886 [00:00<?, ?it/s]

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,latent_30,latent_31,latent_32,latent_33,latent_34,latent_35,latent_36,latent_37,latent_38,latent_39
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196,-0.132937,0.235363,0.221600,0.093827,0.030118,0.082465,0.040064,-0.510622,-0.101321,-0.170131,...,0.101099,-0.461166,-0.029380,-0.068751,0.103694,-0.018498,-0.033903,0.249414,-0.374187,0.145589
3668,-0.311308,-0.265534,-0.525946,0.288166,0.192659,0.347320,-0.390033,0.240070,-0.554002,0.198221,...,0.555309,-0.181685,0.017443,0.674247,-0.101514,0.064378,-0.331595,-0.157873,0.674439,-0.400453
14208,0.027221,-0.202348,-0.377983,0.442892,-0.142330,0.356758,0.111699,-0.191964,0.434534,-0.141312,...,-0.472185,-0.430180,0.185014,-0.490180,0.119323,0.150279,0.119628,0.012445,0.030761,-0.105781
11998,-0.224886,0.012884,0.360798,0.121988,0.213003,-0.186658,-0.316999,0.147217,-0.199629,0.134238,...,0.189048,0.004496,-0.260264,-0.391592,0.452644,0.721014,0.431145,0.065960,0.188932,0.071957
20383,-0.697016,0.310692,0.128409,-0.590824,-0.733731,0.702062,0.198369,0.178599,0.103363,0.579671,...,-0.039232,-0.254711,-0.313832,0.271175,0.058339,0.830758,0.312719,0.054614,-0.018727,-0.331798
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243,0.245220,0.149432,0.335953,0.119903,0.239621,0.317397,0.197524,0.111434,-0.412202,-0.050227,...,0.092058,0.134485,-0.087123,-0.120176,0.254068,-0.040172,-0.000807,-0.270670,-0.265856,0.016476
18045,0.246490,0.049270,0.077744,-0.075257,-0.333197,-0.571181,-0.626580,-0.194702,-0.058584,0.013502,...,-0.387583,0.004530,-0.151820,-0.062971,0.179538,0.260980,-0.172972,0.024776,-0.492894,-0.433465
16081,0.073206,0.036987,0.135165,0.286527,0.016248,-0.113701,0.232936,-0.430729,-0.245568,-0.168867,...,0.054568,0.188731,-0.543949,-0.324444,-0.319623,-0.438653,-0.045533,0.039570,0.068965,-0.218527
15854,-0.111148,0.428887,0.052729,0.205441,-0.340609,-0.478181,0.329602,0.232590,-0.398275,0.221088,...,0.837985,-0.111623,-0.440812,-0.422998,-0.040586,-0.045266,-0.574199,0.195015,0.630407,0.219369


# Assess Predictive Power of Latent Representations for Each SCP Statement in PTB-XL

In [8]:
# Left-join the test-set metadata onto the latent vectors by index, keeping
# every row of latent_df.
metadata = dm.test_ds.dataset.metadata
combined_df = latent_df.merge(metadata, how='left', left_index=True, right_index=True)
combined_df

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196,-0.132937,0.235363,0.221600,0.093827,0.030118,0.082465,0.040064,-0.510622,-0.101321,-0.170131,...,True,,,,,,,9,records100/19000/19196_lr,records500/19000/19196_hr
3668,-0.311308,-0.265534,-0.525946,0.288166,0.192659,0.347320,-0.390033,0.240070,-0.554002,0.198221,...,True,,,,,,,8,records100/03000/03668_lr,records500/03000/03668_hr
14208,0.027221,-0.202348,-0.377983,0.442892,-0.142330,0.356758,0.111699,-0.191964,0.434534,-0.141312,...,True,,,,,,,7,records100/14000/14208_lr,records500/14000/14208_hr
11998,-0.224886,0.012884,0.360798,0.121988,0.213003,-0.186658,-0.316999,0.147217,-0.199629,0.134238,...,True,,,,,,,1,records100/11000/11998_lr,records500/11000/11998_hr
20383,-0.697016,0.310692,0.128409,-0.590824,-0.733731,0.702062,0.198369,0.178599,0.103363,0.579671,...,True,,,,,,,5,records100/20000/20383_lr,records500/20000/20383_hr
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243,0.245220,0.149432,0.335953,0.119903,0.239621,0.317397,0.197524,0.111434,-0.412202,-0.050227,...,True,,,,,,,4,records100/10000/10243_lr,records500/10000/10243_hr
18045,0.246490,0.049270,0.077744,-0.075257,-0.333197,-0.571181,-0.626580,-0.194702,-0.058584,0.013502,...,False,,,,,,,6,records100/18000/18045_lr,records500/18000/18045_hr
16081,0.073206,0.036987,0.135165,0.286527,0.016248,-0.113701,0.232936,-0.430729,-0.245568,-0.168867,...,True,,,,,,,6,records100/16000/16081_lr,records500/16000/16081_hr
15854,-0.111148,0.428887,0.052729,0.205441,-0.340609,-0.478181,0.329602,0.232590,-0.398275,0.221088,...,True,,,,,,,3,records100/15000/15854_lr,records500/15000/15854_hr


In [9]:
import ast
from util import eval_predictive_power_binary_outcome

# SCP statement lookup table. The first CSV column becomes the index, which the
# evaluation loop below iterates as SCP codes; the 'description' column supplies
# the human-readable target name for each result row.
all_scps = pd.read_csv("../data/scp_statements.csv", index_col=0)

def ptb_val_to_pseudobinary_label(scp_code_of_interest: str, this_recording_scp_codes: dict) -> float:
    """Turn one recording's SCP code values into a binary label for a single code.

    Args:
        scp_code_of_interest: SCP code whose label is wanted.
        this_recording_scp_codes: Mapping of SCP code -> value for one recording
            (the already-parsed ``scp_codes`` entry, not its string form).

    Returns:
        0.0 if the code is absent from the recording, 1.0 if it is present with
        a value of exactly 100.0, and NaN for any other value so the caller can
        exclude that recording.
    """
    if scp_code_of_interest not in this_recording_scp_codes:
        return 0.0
    # The code is known to be present from here on; only a value of 100 counts as positive.
    if this_recording_scp_codes[scp_code_of_interest] == 100.0:
        return 1.0
    return float('nan')


# Parse each recording's scp_codes string into a dict once, instead of once per
# (recording, SCP code) pair inside the loop.
parsed_scp_codes = combined_df['scp_codes'].apply(ast.literal_eval)

results = []
for scp_code in tqdm(all_scps.index.to_list()):
    label_column = f'scp.{scp_code}'
    combined_df[label_column] = parsed_scp_codes.apply(
        lambda codes: ptb_val_to_pseudobinary_label(scp_code, codes)
    )
    # NaN labels are excluded from this code's evaluation.
    relevant_df = combined_df[combined_df[label_column].notna()]

    res = eval_predictive_power_binary_outcome(relevant_df[latent_df.columns], relevant_df[label_column])
    res['Target'] = all_scps.loc[scp_code, 'description']
    results.append(res)


results_df = pd.DataFrame.from_records(results)
results_df.nlargest(n=50, columns=['Avg CV score'])


  0%|          | 0/71 [00:00<?, ?it/s]

Unnamed: 0,Total usable,% positive,Avg CV score,Target
20,1884,0.005308,0.783836,ischemic in inferior leads
18,1869,0.006956,0.578057,anterior myocardial infarction
7,1810,0.049724,0.56531,left ventricular hypertrophy
9,1874,0.040555,0.561432,non-specific ischemic
45,1872,0.043803,0.557907,ventricular premature complex
0,1884,0.084926,0.552794,non-diagnostic T abnormalities
6,1821,0.065349,0.541224,anteroseptal myocardial infarction
17,1878,0.011715,0.539493,left atrial overload/enlargement
21,1886,0.009544,0.538551,subendocardial injury in anteroseptal leads
14,1886,0.020148,0.536419,complete right bundle branch block
