This notebook presents a step by step way to get text embeddings for each patient's report based on SCP codes and confidence scores provided in the PTB-XL dataset:

- Input: 
    - ptbxl_database.csv
    - scp_statements.csv
    - Model from emilyalsentzer/Bio_ClinicalBERT

- Output: 
    - A dictionary with patient ID and text embedding (dimension is 768)
    - Saved to patient_embedding_dict_summed_SCP_structured_w_confidence.pkl


In [None]:
import os
import numpy as np
import pickle 
import torch
import pandas as pd
import ast
from transformers import AutoTokenizer, AutoModel
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)
model.eval()



In [None]:
ptbxl_path = "../data/ptbxl/ptbxl_database.csv"
scp_statements_path = "../data/ptbxl/scp_statements.csv"
assert os.path.exists(ptbxl_path) and os.path.exists(scp_statements_path), "ptbxl_database.csv and scp_statements.csv must be downloaded from https://physionet.org/content/ptb-xl/1.0.1/ and placed in the data/ptbxl folder"

report_df = pd.read_csv(ptbxl_path)
report_df.head(5)

def convert_string_to_dict(string):
    return ast.literal_eval(string)
report_df['scp_codes_dict'] = report_df['scp_codes'].apply(convert_string_to_dict)
print ("len of patient",len(report_df))
report_df.head(5)

## iterate over the rows to remove the scp codes with 0

In [3]:
for index, row in report_df.iterrows():
    remove_keys = []
    for key in row['scp_codes_dict'].keys():
        if row['scp_codes_dict'][key]==0:
            remove_keys.append(key)
    for key in remove_keys:
        row['scp_codes_dict'].pop(key)

report_df.tail(5)

Unnamed: 0,ecg_id,patient_id,age,sex,height,weight,nurse,site,device,recording_date,...,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,scp_codes_dict
21832,21833,17180.0,67.0,1,,,1.0,2.0,AT-60 3,2001-05-31 09:14:35,...,,", alles,",,,1ES,,7,records100/21000/21833_lr,records500/21000/21833_hr,"{'NDT': 100.0, 'PVC': 100.0}"
21833,21834,20703.0,93.0,0,,,1.0,2.0,AT-60 3,2001-06-05 11:33:39,...,,,,,,,4,records100/21000/21834_lr,records500/21000/21834_hr,{'NORM': 100.0}
21834,21835,19311.0,59.0,1,,,1.0,2.0,AT-60 3,2001-06-08 10:30:27,...,,", I-AVR,",,,,,2,records100/21000/21835_lr,records500/21000/21835_hr,{'ISCAS': 50.0}
21835,21836,8873.0,64.0,1,,,1.0,2.0,AT-60 3,2001-06-09 18:21:49,...,,,,,SVES,,8,records100/21000/21836_lr,records500/21000/21836_hr,{'NORM': 100.0}
21836,21837,11744.0,68.0,0,,,1.0,2.0,AT-60 3,2001-06-11 16:43:01,...,,", I-AVL,",,,,,9,records100/21000/21837_lr,records500/21000/21837_hr,{'NORM': 100.0}


In [4]:
df = pd.read_csv(scp_statements_path)
print(len(df))
print("SCP code description")
df.head(20)

71
SCP code description


Unnamed: 0,key,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
0,NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
1,NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
2,DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
3,LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
4,NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7
5,IMI,inferior myocardial infarction,1.0,,,MI,IMI,Myocardial Infarction,inferior myocardial infarction,161.0,,,
6,ASMI,anteroseptal myocardial infarction,1.0,,,MI,AMI,Myocardial Infarction,anteroseptal myocardial infarction,165.0,,,
7,LVH,left ventricular hypertrophy,1.0,,,HYP,LVH,Ventricular Hypertrophy,left ventricular hypertrophy,142.0,,C71076,
8,LAFB,left anterior fascicular block,1.0,,,CD,LAFB/LPFB,Intraventricular and intra-atrial Conduction d...,left anterior fascicular block,101.0,MDC_ECG_BEAT_BLK_ANT_L_HEMI,C62267,D3-33140
9,ISC_,non-specific ischemic,1.0,,,STTC,ISC_,Basic roots for coding ST-T changes and abnorm...,ischemic ST-T changes,226.0,,,


## for each record, construct its text embeddings. 

In [None]:
from collections import OrderedDict
text_features = []
patient_embedding_dict = {}
## give an ordered dictionary to cache the SCP embeddings
cache_dict = OrderedDict()

for index, row in report_df.iterrows():
    # for each patient, get the SCP codes and their confidence scores
    sentence_count =0
    encoded_inputs_embedding_list = []
    uncertainty_list = []
    total_prompt = ''
    for key, uncertainty in row['scp_codes_dict'].items():
        sentence_count+=1
        scaled_uncertainty = uncertainty/100.0
        uncertainty_list.append(scaled_uncertainty)
        ## promp is defined here as the SCP code categorty and the SCP code description
        string_text = "[CLS] "+df[df['key']==key]['Statement Category'].values[0]+':'
        string_text+=df[df['key']==key]['SCP-ECG Statement Description'].values[0]+" [SCP]"
        print(string_text)
      
        if key in cache_dict.keys():
            ## if related SCP embedding is already calculated, directly use it
            embeddings = cache_dict[key]
            encoded_inputs_embedding_list.append(embeddings)
        else:
            encoded_inputs = tokenizer(string_text, add_special_tokens=False, truncation = True,return_tensors="pt",max_length=100, padding = 'max_length')
            input_ids = encoded_inputs['input_ids']
            segments_tensors = encoded_inputs['token_type_ids']
            attention_mask = encoded_inputs['attention_mask']
            with torch.inference_mode():
                model_output = model(input_ids = input_ids.to(device),attention_mask= attention_mask.to(device))
                embeddings = model_output.last_hidden_state
                embeddings = torch.mean(embeddings,keepdim=True,dim=1)
                embeddings = embeddings.squeeze(0)
            cache_dict.update({key:embeddings})
            encoded_inputs_embedding_list.append(embeddings)
    multi_embedding =None
    sum_uncertainty = sum(uncertainty_list)
    i= 0
    ## weighted average of SCP embeddings based on confidence scores
    for embedding,uncertainty in zip(encoded_inputs_embedding_list,uncertainty_list):
        if i==0:
            multi_embedding = embedding*(uncertainty/sum_uncertainty)
        else:multi_embedding+=embedding*(uncertainty/sum_uncertainty)
        i+=1
    print(multi_embedding.shape) 
    ## save the dictionary with patient ecg id as key,  and its text embedding as value
    patient_embedding_dict.update({row['ecg_id']:multi_embedding.cpu().numpy().squeeze()})

with open('patient_embedding_dict_summed_SCP_structured_w_confidence.pkl', 'wb') as f:
    pickle.dump(patient_embedding_dict, f)




## visualize different text embeddings with different scp codes with UMAP


In [9]:
import seaborn as sns
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from umap import UMAP
import seaborn as sns
import plotly.express as px

sns.set(style='white', context='poster', rc={'figure.figsize':(14,10)})


## flatten the dictionary to get the embeddings
input_batch_embeddings = torch.stack([torch.tensor(embedding).detach() for embedding in cache_dict.values()])
feature_df  = pd.DataFrame()
scp_codes = [key for key in cache_dict.keys()]
print((scp_codes))
feature_df = pd.DataFrame(scp_codes)
# feature_df['scp_codes'] = scp_codes
diagnostic_class = [str(df[df['key']==key]['diagnostic_class'].values[0]) for key in cache_dict.keys()]
feature_df['diagnostic_class'] = diagnostic_class
print((diagnostic_class))
print(len(diagnostic_class))
scp_statements = [df[df['key']==key]['SCP-ECG Statement Description'].values[0] for key in cache_dict.keys()]
statement_descriptions = [df[df['key']==key]['SCP-ECG Statement Description'].values[0] for key in cache_dict.keys()]

#

['NORM', 'IMI', 'AFLT', 'NDT', 'NST_', 'DIG', 'LVH', 'LPFB', 'LNGQT', 'LAFB', 'IRBBB', 'RAO/RAE', 'RVH', 'IVCD', 'LMI', 'ASMI', 'AMI', 'ISCAL', '1AVB', 'ISC_', 'PACE', 'ISCLA', 'SEHYP', 'ISCIL', 'ILMI', 'PVC', 'CRBBB', 'CLBBB', 'ALMI', 'ANEUR', 'ISCAS', 'EL', 'LAO/LAE', 'ILBBB', 'ISCIN', 'AFIB', 'INJAS', 'INJAL', 'IPMI', 'WPW', 'ISCAN', 'INJLA', 'IPLMI', '3AVB', 'PAC', 'INJIL', '2AVB', 'PSVT', 'PMI', 'STACH', 'INJIN', 'BIGU']
['NORM', 'MI', 'nan', 'STTC', 'STTC', 'STTC', 'HYP', 'CD', 'STTC', 'CD', 'CD', 'HYP', 'HYP', 'CD', 'MI', 'MI', 'MI', 'STTC', 'CD', 'STTC', 'nan', 'STTC', 'HYP', 'STTC', 'MI', 'nan', 'CD', 'CD', 'MI', 'STTC', 'STTC', 'STTC', 'HYP', 'CD', 'STTC', 'nan', 'MI', 'MI', 'MI', 'CD', 'STTC', 'MI', 'MI', 'CD', 'nan', 'MI', 'CD', 'nan', 'MI', 'nan', 'MI', 'nan']
52



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [None]:


# proj_2d = umap_2d.fit_transform(input_batch_embeddings.detach().cpu().numpy().squeeze())

sns.set(style='white', context='poster', rc={'figure.figsize':(14,10)})
umap_2d = UMAP(random_state=42)
fig_2d = plt.figure(figsize=(12, 12))
proj_2d = umap_2d.fit_transform(input_batch_embeddings[:,0,:].detach().cpu().numpy())
# df['feature'] = proj_2d.tolist()
feature_df['feature'] = proj_2d.tolist()
fig_2d = px.scatter(proj_2d,x=0,y=1,color = feature_df['diagnostic_class'], 
                    symbol = scp_codes, height=600,width=600)  # O)
fig_2d.show()