In [1]:
from ptbxlae.dataprocessing.dataModules import SingleCycleCachedDM
import pandas as pd
from tqdm.auto import tqdm

# Switch off pandas' chained-assignment check for the whole notebook
# (the option's default value is 'warn').
pd.set_option("mode.chained_assignment", None)

*PTB-XL Autoencoder*


# Get Latent Representations

In [2]:
from ptbxlae.modeling.singleCycleConv import SingleCycleConvVAE
import torch

# Build the test-split dataloader from the cached single-cycle data.
dm = SingleCycleCachedDM(cache_folder="../cache/singlecycle_data")
dm.setup(stage="test")
dl = dm.test_dataloader()

# Restore the trained VAE from its checkpoint, put it in eval mode and move it
# to the CPU (presumably to match the device of the dataloader batches — confirm).
# NOTE(review): checkpoint loading goes through torch.load with pickle, which can
# execute arbitrary code — only load checkpoints from a trusted source.
m = SingleCycleConvVAE.load_from_checkpoint('../cache/savedmodels/last-v1.ckpt').eval()
m.cpu()

# Encode every test batch and stack the results into one NumPy array.
# This is inference only, so disable autograd: otherwise each batch's graph is
# kept alive inside the list until the final detach, wasting memory and time.
with torch.no_grad():
    latent_representations = torch.cat([m.encode(x) for x in tqdm(dl)]).cpu().detach().numpy()


/home/isears/VirtualEnvironments/default/lib/python3.11/site-packages/lightning/fabric/utilities/cloud_io.py:56: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


  0%|          | 0/59 [00:00<?, ?it/s]

In [3]:
# One column per latent dimension. The width is read from the encoder output
# instead of being hard-coded to 25, so a checkpoint with a different latent
# size does not require editing this cell.
n_latent_dims = latent_representations.shape[1]
latent_df = pd.DataFrame(
    data=latent_representations,
    columns=[f"latent.{i}" for i in range(n_latent_dims)],
)
latent_df

Unnamed: 0,latent.0,latent.1,latent.2,latent.3,latent.4,latent.5,latent.6,latent.7,latent.8,latent.9,...,latent.15,latent.16,latent.17,latent.18,latent.19,latent.20,latent.21,latent.22,latent.23,latent.24
0,-1.331678,0.019534,0.423945,0.158147,0.161493,2.691121,-0.373091,0.319787,0.357783,-1.129923,...,-1.052381,0.063345,0.413455,-0.175104,-0.151516,-0.108899,0.163770,-1.517453,0.628882,-0.474045
1,-0.614352,-0.771474,0.877454,0.566261,-1.542497,0.333332,0.849198,0.689939,0.288457,1.298107,...,-0.208082,0.300257,-0.696106,0.693703,0.699262,1.763966,-0.517520,-0.986676,1.616176,0.819023
2,0.448710,0.467811,-0.379832,0.348398,-0.139661,1.127355,1.490121,2.123297,1.350752,0.028545,...,0.007084,0.152813,-0.993664,-1.644603,-1.330565,1.421050,2.453838,1.338293,-0.303993,-0.863238
3,-0.779672,-1.227710,0.027010,2.174850,1.672388,-0.139133,-0.499847,-1.556339,0.405783,0.432769,...,-0.050415,0.784577,-0.015156,-0.077316,-1.523136,-0.097421,-0.684979,0.270716,-0.065140,-0.141831
4,1.127179,-1.322895,0.431583,1.024616,-0.193665,-0.340802,0.923903,0.071978,1.288311,-0.027408,...,-0.011160,1.704053,-1.026659,1.321317,-0.776751,-0.982715,0.178803,1.537049,0.400800,1.358041
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1881,1.516409,-0.954753,2.073724,-3.148396,0.695662,0.219924,-1.482243,0.062958,-0.559998,-0.532423,...,0.302150,-1.566902,1.050687,0.370122,-1.860046,-0.031486,0.746346,-1.316110,0.346080,1.872739
1882,-0.204663,-0.312426,0.519306,1.038776,1.391322,-0.255679,-0.553429,-0.287995,-0.566991,0.794041,...,-1.116883,-0.469963,1.006721,0.546989,-0.268800,-1.060930,-0.890039,-0.580312,-1.065618,0.606669
1883,1.462117,-0.800125,-0.006636,-0.837234,0.787176,-1.686111,-0.861688,0.411846,-0.509612,-1.265883,...,1.334992,-0.587565,0.794985,0.470822,-0.212306,-1.560618,-0.243005,0.168051,-1.836245,0.079734
1884,0.306671,0.773609,-1.082525,1.668587,0.202265,-1.907422,0.085877,-0.307616,-0.635959,-0.093182,...,0.650112,2.468080,0.798801,-0.020331,0.524527,-0.063154,-1.697899,0.439860,1.190037,1.018127


# Utility Function: Logistic Regression on Latent Representations for a Specified Target (y)

In [4]:
# Util function to do a quick LR on a specified target based on latent variables
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def eval_predictive_power_binary_outcome(x, y):
    """Estimate how well the features ``x`` predict the binary target ``y``.

    Runs a default ``LogisticRegression`` through 5-fold cross-validation
    scored with ROC AUC and reports the mean score across folds.

    Args:
        x: Features, one row per sample (passed straight to ``cross_val_score``).
        y: Binary labels (0/1) aligned with the rows of ``x``.

    Returns:
        dict with three keys, in this order:
            "Total usable": ``len(y)``.
            "% positive": ``sum(y) / len(y)``, i.e. a fraction in [0, 1] rather
                than a percentage; NaN when ``y`` is empty.
            "Avg CV score": mean cross-validated ROC AUC, or NaN when ``y`` has
                fewer than 10 positives or cross-validation raises ``ValueError``.
    """
    n_total = len(y)
    n_positive = sum(y)

    ret = {
        "Total usable": n_total,
        # Guard the division so an empty target yields NaN instead of raising
        # ZeroDivisionError.
        "% positive": n_positive / n_total if n_total else float('nan'),
        # Start from NaN; only overwritten when cross-validation succeeds.
        "Avg CV score": float('nan'),
    }

    if n_positive < 10:  # Won't be able to do CV
        return ret

    try:
        # The estimator is only built once we know CV will actually be attempted.
        scores = cross_val_score(LogisticRegression(), x, y, cv=5, scoring='roc_auc')
    except ValueError:
        # Best effort: a target that cannot be cross-validated is reported as NaN.
        return ret

    ret['Avg CV score'] = sum(scores) / len(scores)
    return ret

# Assess Predictive Power of Latent Representations for Each Diagnostic Label in PTB-XL

In [5]:
import ast

# (Re)initialise the test split, then pull the metadata rows that belong to it.
dm.setup(stage="test")

# test_ds exposes the underlying dataset plus the row positions selected for
# this split; reset_index() renumbers the selected rows from 0 and keeps the
# previous index as a column.
metadata = (
    dm.test_ds.dataset.metadata
    .iloc[dm.test_ds.indices]
    .reset_index()
)

# SCP statement lookup table; the first CSV column becomes the index.
all_scps = pd.read_csv(
    "../data/scp_statements.csv",
    index_col=0,
)

def ptb_val_to_pseudobinary_label(scp_code_of_interest: str, this_recording_scp_codes: dict) -> float:
    """Turn one recording's SCP-code mapping into a label for a single code.

    Args:
        scp_code_of_interest: The SCP code to look up.
        this_recording_scp_codes: Mapping from SCP code to its recorded value
            for one recording (the parsed ``scp_codes`` entry, not the raw string).

    Returns:
        0.0 if the code is absent from the mapping, 1.0 if it is present with a
        value equal to 100.0, and NaN if it is present with any other value.
    """
    if scp_code_of_interest not in this_recording_scp_codes:
        return 0.0
    # The code is present from here on, so only its value needs checking.
    if this_recording_scp_codes[scp_code_of_interest] == 100.0:
        return 1.0
    return float('nan')

# Place each recording's metadata next to its latent vector. concat(axis=1)
# aligns on the index and pads with NaN when the frames differ in length, so
# make sure there is exactly one latent row per metadata row first.
assert len(metadata) == len(latent_df), "metadata and latent_df must have the same number of rows"
combined_df = pd.concat([metadata, latent_df], axis=1)

# The scp_codes column holds dict literals stored as strings. Parse each row once
# here rather than re-parsing every row for every SCP code inside the loop.
parsed_scp_codes = combined_df['scp_codes'].apply(ast.literal_eval)

results = list()
for scp_code in tqdm(all_scps.index.to_list()):
    label_col = f'scp.{scp_code}'
    combined_df[label_col] = parsed_scp_codes.apply(lambda codes: ptb_val_to_pseudobinary_label(scp_code, codes))
    # Rows whose label is NaN are excluded from the evaluation of this code.
    relevant_df = combined_df[combined_df[label_col].notna()]

    res = eval_predictive_power_binary_outcome(relevant_df[latent_df.columns], relevant_df[label_col])
    res['Target'] = all_scps.loc[scp_code, 'description']
    results.append(res)


# One row per SCP code; display the 50 best-predicted targets.
results_df = pd.DataFrame.from_records(results)
results_df.nlargest(n=50, columns=['Avg CV score'])


  0%|          | 0/71 [00:00<?, ?it/s]

Unnamed: 0,Total usable,% positive,Avg CV score,Target
25,1884,0.005839,0.67249,ischemic in anteroseptal leads
21,1886,0.005302,0.666434,subendocardial injury in anteroseptal leads
20,1878,0.006922,0.647542,ischemic in inferior leads
19,1875,0.005867,0.636206,anterolateral myocardial infarction
10,1882,0.053135,0.594204,incomplete right bundle branch block
6,1813,0.079978,0.590447,anteroseptal myocardial infarction
27,1886,0.006893,0.584717,ischemic in lateral leads
1,1856,0.024246,0.539239,non-specific ST changes
17,1876,0.017591,0.528445,left atrial overload/enlargement
12,1886,0.036055,0.525635,non-specific intraventricular conduction distu...
