In [1]:
from ptbxlae.dataprocessing.dataModules import SingleCycleCachedDM
from ptbxlae.modeling.singleCycleConv import SingleCycleConvVAE

import pandas as pd
from tqdm.auto import tqdm
import os
import torch
import numpy as np

# Turn off pandas' chained-assignment (SettingWithCopy) warning for the whole kernel session.
# NOTE(review): this is a global switch and also hides genuine chained-assignment
# mistakes -- confirm it is still needed by the cells below.
pd.options.mode.chained_assignment = None  # default='warn'

*PTB-XL Autoencoder*


# Get Latent Representations

In [2]:
# Load the cached single-cycle test split and the trained VAE, then build one
# averaged latent vector per test patient (using a single ECG per patient).
dm = SingleCycleCachedDM(cache_folder="../cache/singlecycle_data")
dm.setup(stage="test")
metadata = dm.test_ds.dataset.metadata

m = SingleCycleConvVAE.load_from_checkpoint('../cache/savedmodels/last-v2.ckpt').eval()
m.cpu()

latent_dicts = list()

collected_latents = list()

# Column labels are the same for every recording, so build them once.
latent_columns = [f'latent_{x}' for x in range(0, m.latent_dim)]

# A bare `torch.no_grad()` statement only constructs a context manager and throws it
# away; it has to be entered with `with` for autograd to actually be disabled.
with torch.no_grad():
    for test_index in tqdm(dm.test_ds.indices):
        pid = dm.test_ds.dataset.patient_ids[test_index]
        patient_dir = f"../cache/singlecycle_data/{pid}"

        # For purposes of testing, only consider first ecg in patient directory.
        # os.listdir() returns entries in arbitrary order, so sort to make the
        # choice of "first" reproducible across runs and machines.
        ecg_id = sorted(os.listdir(patient_dir))[0]
        ecg_dir = f"{patient_dir}/{ecg_id}"
        cycles = os.listdir(ecg_dir)

        # One parquet file per cycle; each is transposed and stacked along a new leading axis.
        batched_cycles = np.stack([pd.read_parquet(f"{ecg_dir}/{c}").to_numpy().transpose() for c in cycles])

        # Average the per-cycle encodings into a single vector for this recording.
        latent_representations = m.encode(torch.Tensor(batched_cycles)).mean(dim=0)

        # Keep the ECG id as the Series *name* rather than as a value inside the float
        # Series, so it stays an integer instead of being upcast (e.g. 19196.0).
        labeled_series = pd.Series(
            data=latent_representations.numpy(),
            index=latent_columns,
            name=int(ecg_id),
        )
        # labeled_series['patient_id'] = int(pid)

        collected_latents.append(labeled_series)


# Series names become column labels in the concat; transposing turns them into the row index.
latent_df = pd.concat(collected_latents, axis=1).T.rename_axis('ecg_id')
latent_df

/home/isears/VirtualEnvironments/default/lib/python3.11/site-packages/lightning/fabric/utilities/cloud_io.py:56: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


  0%|          | 0/1886 [00:00<?, ?it/s]

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,latent_30,latent_31,latent_32,latent_33,latent_34,latent_35,latent_36,latent_37,latent_38,latent_39
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196.0,0.056560,-0.248328,0.469551,-0.125453,-0.468456,0.461901,-0.497111,0.507668,0.424205,0.184844,...,-0.362406,-0.049523,0.088597,-0.085305,0.016143,-0.424340,-0.174764,-0.403961,0.078723,-0.289868
3668.0,0.234401,-0.204226,-0.913933,-0.111487,1.001097,0.417763,0.447499,0.414569,0.231814,0.500907,...,0.297847,0.179672,-0.394955,-0.293370,-0.110164,0.123528,-0.436464,0.072824,0.707503,-0.056817
14208.0,-0.532175,-0.062223,-1.214666,-0.418501,-0.442498,0.799148,0.274425,0.289443,-0.159931,-0.326377,...,-0.190818,-0.246083,-0.434323,-0.037794,0.182878,0.391407,0.064133,0.109485,-1.449176,-0.367159
11998.0,-0.169199,-0.088786,0.766018,0.304525,0.388508,0.007866,0.566686,-0.127804,0.161397,0.351331,...,1.049050,-1.077612,-0.513012,0.242399,0.254880,-0.132012,-0.107794,-0.295822,0.374165,0.001363
20383.0,-0.389747,0.023635,-1.171100,-0.606851,0.289316,-0.461197,0.261198,0.069097,0.076556,-0.239444,...,1.193573,-0.258209,0.038261,-0.092156,-0.110255,0.527212,0.070424,-0.368724,0.794358,0.163235
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243.0,-0.557648,0.353242,1.083321,-0.364656,1.951932,0.589303,0.918711,0.016154,-0.077019,1.657878,...,-2.533943,-0.923792,0.074770,0.105297,0.406131,-0.309094,0.339756,0.514692,-0.468235,-0.074978
18045.0,-0.229511,-0.036815,0.838845,0.047916,1.473472,0.445969,0.300061,0.003369,-0.133302,0.595042,...,0.224626,-0.079164,-0.043212,0.306185,0.112613,-0.081211,0.263763,-0.329937,0.298368,0.029271
16081.0,-0.614988,-0.046837,0.871279,-0.498187,0.525017,0.837033,0.261638,-0.142268,-0.278457,-0.197514,...,0.392861,0.283502,0.217922,-0.187737,0.396323,-0.095462,0.321989,0.056623,-0.167848,0.915674
15854.0,-0.818308,-0.289046,-0.013039,-0.070141,-0.011235,0.114491,-0.048763,1.225818,0.474634,-0.693645,...,1.354771,-0.895718,0.710110,0.448327,-0.280324,0.047711,0.645329,0.108266,2.511841,-0.636431


# Util Function: Logistic Regression (LR) Based on Latent Representations for Specified Targets (y)

In [3]:
# Util function to do a quick LR on a specified target based on latent variables
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def eval_predictive_power_binary_outcome(x, y):
    """Score how well the features ``x`` predict the binary target ``y``.

    A default ``LogisticRegression`` is evaluated with 5-fold cross-validation
    using ROC AUC as the metric.

    Args:
        x: Feature matrix, one row per sample (passed straight to ``cross_val_score``).
        y: Sized iterable of labels, one per row of ``x``. Labels are summed to count
            positives, so they are expected to be 0/1.

    Returns:
        dict with the keys:
            ``"Total usable"``  -- number of samples, ``len(y)``.
            ``"% positive"``    -- fraction of positives (``sum(y) / len(y)``, i.e. a
                                   value in [0, 1], not a percentage); NaN when ``y`` is empty.
            ``"Avg CV score"``  -- mean ROC AUC over the folds, or NaN when there are
                                   fewer than 10 positives or cross-validation raises
                                   ``ValueError``.
    """
    n_total = len(y)
    n_positive = sum(y)

    ret = {
        "Total usable": n_total,
        # Guard the division so an empty target yields NaN instead of ZeroDivisionError.
        "% positive": n_positive / n_total if n_total else float('nan')
    }

    if n_positive < 10:  # Won't be able to do CV
        ret['Avg CV score'] = float('nan')
        return ret

    try:
        # The estimator is only built once we know cross-validation will actually run.
        scores = cross_val_score(LogisticRegression(), x, y, cv=5, scoring='roc_auc')
        ret['Avg CV score'] = sum(scores) / len(scores)
    except ValueError:
        # Deliberate best-effort: a target that cannot be cross-validated is reported as NaN.
        ret['Avg CV score'] = float('nan')

    return ret

# Assess Predictive Power of Latent Representations for Each Diagnostic Label in PTB-XL

In [4]:
# Attach the dataset's metadata columns to each row of latents. Both frames are
# joined on their index, and the left join keeps every row of latent_df.
metadata = dm.test_ds.dataset.metadata
combined_df = latent_df.merge(
    metadata,
    how='left',
    left_index=True,
    right_index=True,
)
combined_df

Unnamed: 0_level_0,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
19196.0,0.056560,-0.248328,0.469551,-0.125453,-0.468456,0.461901,-0.497111,0.507668,0.424205,0.184844,...,True,,,,,,,9,records100/19000/19196_lr,records500/19000/19196_hr
3668.0,0.234401,-0.204226,-0.913933,-0.111487,1.001097,0.417763,0.447499,0.414569,0.231814,0.500907,...,True,,,,,,,8,records100/03000/03668_lr,records500/03000/03668_hr
14208.0,-0.532175,-0.062223,-1.214666,-0.418501,-0.442498,0.799148,0.274425,0.289443,-0.159931,-0.326377,...,True,,,,,,,7,records100/14000/14208_lr,records500/14000/14208_hr
11998.0,-0.169199,-0.088786,0.766018,0.304525,0.388508,0.007866,0.566686,-0.127804,0.161397,0.351331,...,True,,,,,,,1,records100/11000/11998_lr,records500/11000/11998_hr
20383.0,-0.389747,0.023635,-1.171100,-0.606851,0.289316,-0.461197,0.261198,0.069097,0.076556,-0.239444,...,True,,,,,,,5,records100/20000/20383_lr,records500/20000/20383_hr
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10243.0,-0.557648,0.353242,1.083321,-0.364656,1.951932,0.589303,0.918711,0.016154,-0.077019,1.657878,...,True,,,,,,,4,records100/10000/10243_lr,records500/10000/10243_hr
18045.0,-0.229511,-0.036815,0.838845,0.047916,1.473472,0.445969,0.300061,0.003369,-0.133302,0.595042,...,False,,,,,,,6,records100/18000/18045_lr,records500/18000/18045_hr
16081.0,-0.614988,-0.046837,0.871279,-0.498187,0.525017,0.837033,0.261638,-0.142268,-0.278457,-0.197514,...,True,,,,,,,6,records100/16000/16081_lr,records500/16000/16081_hr
15854.0,-0.818308,-0.289046,-0.013039,-0.070141,-0.011235,0.114491,-0.048763,1.225818,0.474634,-0.693645,...,True,,,,,,,3,records100/15000/15854_lr,records500/15000/15854_hr


In [5]:
import ast

# Table of SCP statements, indexed by its first CSV column. The index values are
# used below as the SCP codes to iterate over, and the 'description' column
# supplies the human-readable target name.
all_scps = pd.read_csv("../data/scp_statements.csv", index_col=0)

def ptb_val_to_pseudobinary_label(scp_code_of_interest: str, this_recording_scp_codes: dict) -> float:
    """Turn one recording's SCP-code mapping into a 0 / 1 / NaN label for a single code.

    Args:
        scp_code_of_interest: The SCP code to build a label for.
        this_recording_scp_codes: Mapping of SCP code -> value for one recording
            (already parsed into a dict, not the raw string).

    Returns:
        0.0 if the code is absent from the mapping,
        1.0 if the code is present with a value equal to 100.0,
        NaN if the code is present with any other value, so callers can drop
        those rows.
    """
    if scp_code_of_interest not in this_recording_scp_codes:
        return 0.0

    if this_recording_scp_codes[scp_code_of_interest] == 100.0:
        return 1.0

    # Present, but with a value other than 100: neither a clean positive nor a clean negative.
    return float('nan')


# Parse the stringified SCP-code dicts once up front; the parse does not depend on
# which code is being evaluated, so repeating it inside the loop is wasted work.
parsed_scp_codes = combined_df['scp_codes'].apply(ast.literal_eval)

results = list()
for scp_code in tqdm(all_scps.index.to_list()):
    target_col = f'scp.{scp_code}'

    # 0.0 / 1.0 / NaN label for this code, stored on combined_df as before.
    combined_df[target_col] = parsed_scp_codes.apply(
        lambda codes: ptb_val_to_pseudobinary_label(scp_code, codes)
    )

    # Rows labelled NaN are excluded from this target's evaluation.
    relevant_df = combined_df[combined_df[target_col].notna()]

    res = eval_predictive_power_binary_outcome(relevant_df[latent_df.columns], relevant_df[target_col])
    # Single .loc lookup instead of chained row-then-column indexing.
    res['Target'] = all_scps.loc[scp_code, 'description']
    results.append(res)


# One row per SCP code; show the targets the latents predict best.
results_df = pd.DataFrame.from_records(results)
results_df.nlargest(n=50, columns=['Avg CV score'])


  0%|          | 0/71 [00:00<?, ?it/s]

Unnamed: 0,Total usable,% positive,Avg CV score,Target
15,1886,0.027041,0.996151,complete left bundle branch block
14,1886,0.020148,0.986773,complete right bundle branch block
8,1884,0.064756,0.965775,left anterior fascicular block
7,1810,0.049724,0.950291,left ventricular hypertrophy
9,1874,0.040555,0.926361,non-specific ischemic
21,1886,0.009544,0.913822,subendocardial injury in anteroseptal leads
6,1821,0.065349,0.906405,anteroseptal myocardial infarction
16,1869,0.010166,0.902658,inferolateral myocardial infarction
4,1674,0.419355,0.900138,normal ECG
19,1877,0.008524,0.897264,anterolateral myocardial infarction
