In [None]:
import pandas as pd
import os
from datetime import datetime
import seaborn as sns
from pyhealth.datasets import MIMIC3Dataset
from datetime import datetime
import plotly.express as px
import matplotlib.pyplot as plt
import numpy as np 
from sklearn.decomposition import PCA
import plotly.graph_objects as go
from collections import Counter
from plotly.subplots import make_subplots
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Path to the pretrained medical-concept embeddings text file (Choi et al., 2016,
# "Learning low-dimensional representations of medical concepts").
# Placeholder: set to the local file before running.
embs = r"PATH_TO_EMBEDDINGS.txt"

In [None]:
# Read the whole embedding file into memory, one string per line
# (line 0 is a header, every other line is one token and its vector).
with open(embs) as emb_file:
    lines = list(emb_file)

len(lines)  # 51327 tokens + header

In [None]:
# Parse every line after the header: the first whitespace-separated field is the
# code, the remaining fields are the components of its embedding vector.
codes, embeddings = [], []
for raw_line in lines[1:]:
    fields = raw_line.split()
    codes.append(fields[0])
    embeddings.append(np.array(fields[1:], dtype=float).tolist())

emb_df = pd.DataFrame({'code': codes, 'embedding': embeddings})

# The first token in the file is the </s> marker, not a medical code.
emb_df = emb_df.drop(emb_df.index[0])
print(emb_df.head(), emb_df.shape)

In [None]:
# Codes look like "<PREFIX>_<code>"; the prefix identifies the code vocabulary.
emb_df['prefix'] = emb_df['code'].str.split('_').str[0]
code_counts = emb_df['prefix'].value_counts()
print(code_counts)

In [None]:
# Stack all concept vectors into one matrix and keep as many principal
# components as needed to explain 95% of the variance.
emb_array = np.vstack(emb_df['embedding'].to_list())
pca = PCA(n_components=0.95)
pca_emb = pca.fit_transform(emb_array)
pca_emb.shape

In [None]:
# One row per concept: its PCA coordinates plus the code prefix used for colouring.
pca_df = pd.DataFrame(pca_emb).assign(Prefix=emb_df['prefix'].to_numpy())

In [None]:
# 2-D view: first two principal components, coloured by code vocabulary.
first_pc = pca_df.columns[0]
second_pc = pca_df.columns[1]
fig = px.scatter(
    pca_df,
    x=first_pc,
    y=second_pc,
    color=pca_df['Prefix'],
    title="PCA: Medical concept embeddings",
    width=900,
    height=700,
)

fig.show()

In [None]:
# Fixed colour per code vocabulary (prefix) for the 3-D view.
code_mapping = {
    'C': '#636EFA',
    'L': '#EF553B',
    'IDX': '#00CC96',
    'N': '#AB63FA',
    'IPR': '#FFA15A',
}
colors = pca_df['Prefix'].map(code_mapping)

# First three PCA columns as x / y / z.
axis_cols = pca_df.columns
concept_points = go.Scatter3d(
    x=pca_df[axis_cols[0]],
    y=pca_df[axis_cols[1]],
    z=pca_df[axis_cols[2]],
    mode='markers',
    marker=dict(size=2, color=colors, opacity=0.8),
    text=pca_df['Prefix'],
    hoverinfo='text',
)

fig = go.Figure(data=[concept_points])
fig.update_layout(
    title="Embeddings PCA, Learning low-dimensional representations of medical concepts, Choi 2016",
    margin=dict(l=0, r=0, b=0, t=30),
)
fig.show()

## Types of codes
The embedding set contains 5 kinds of codes. In this work we use only the first three listed below (IDX, N and IPR).
- IDX = ICD-9 Diagnosis codes
- N = NDC Medication codes
- IPR = ICD-9 Procedure codes
- C = CPT Codes (CPTEVENTS table)
- L = LOINC Codes (Lab results) 

In [None]:
# Build an explicit two-column frame from the prefix counts. Indexing
# pd.DataFrame(code_counts)['prefix'] only works on pandas < 2.0: since 2.0,
# value_counts() names the Series 'count' (and its index 'prefix').
code_counts_df = (
    code_counts
    .rename_axis('Code prefix')
    .reset_index(name='Frequency')
)

# Passing column names (not raw arrays) lets plotly use them as axis titles,
# which is what the previous `labels` mapping intended but never applied.
fig = px.bar(code_counts_df, x='Code prefix', y='Frequency', text='Frequency',
             title='Types of codes in the embedding set',
             width=800, height=500)

fig.update_xaxes(tickangle=45)
fig.update_traces(textposition='outside')
fig.show()

In [None]:
# Keep only the vocabularies used downstream: IDX (diagnoses), N (medications),
# IPR (procedures). `.copy()` makes emb_df an independent frame rather than a
# slice of the unfiltered one, so later column assignments on it (or on frames
# derived from it) do not trigger SettingWithCopyWarning.
emb_df = emb_df[emb_df['prefix'].isin(["IDX", "N", "IPR"])].copy()
emb_df.shape

## Preprocessing: creating the 'medical codes for HF prediction' dataset

In this dataset each row is a patient, and the features represent that patient's medical history: the mean of the embeddings of their diagnosis and procedure (ICD-9) codes and medication (NDC) codes across visits.

Since the goal is to predict heart failure (HF), each patient is labelled later on according to whether an HF diagnosis code appears in their history.

Optionally, these features can then be concatenated with demographic data (this step is currently commented out).

#### Creating the patient dictionary with PyHealth

In [None]:
# Directory holding the MIMIC-III tables (placeholder: set before running).
path_to_mimiciii_tables = r"PATH_TO_TABLES"

# MIMIC-III tables to load; PyHealth supports diagnoses, procedures and prescriptions.
tables = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"]

# ICD-9 diagnosis codes indicating heart failure, from the New York State Department
# of Health ICD workbook. Written as dot-less strings so they compare equal to the
# `code` values of PyHealth diagnosis events.
HF_ICD9_CODES = [
    "4250", "4251", "4252", "4253", "4254", "4255", "4257", "4258", "4259", "42983",
    "42830", "42832", "42840", "42842",
    "39891", "4280", "4281", "42820", "42822", "4289",
    "40211", "40291",
    "40411", "40491",
    "4168", "4169",
]

In [None]:
# loading mimic-iii dataset with pyhealth
def load_mimic_pyhealth(path, tables_to_load=None):
    """Load MIMIC-III with PyHealth's MIMIC3Dataset.

    Args:
        path: directory containing the MIMIC-III tables.
        tables_to_load: table names passed as MIMIC3Dataset(tables=...). Defaults
            to the notebook-level `tables` list, so existing one-argument calls
            behave exactly as before; pass a list to avoid relying on that global.

    Returns:
        The MIMIC3Dataset instance.
    """
    if tables_to_load is None:
        tables_to_load = tables
    return MIMIC3Dataset(
        root=path,
        tables=tables_to_load,
        dev=False,
        # code_mapping={"NDC": "ATC"},
        refresh_cache=False,
    )

# Load the MIMIC-III tables listed in `tables` into a PyHealth dataset object.
mimic_data = load_mimic_pyhealth(path_to_mimiciii_tables)
print(mimic_data)

In [None]:
# mimic_data.patients is a dict: key = patient_id, value = Patient object
# (each Patient has a `visits` dict of visit_id -> Visit).
# PyHealth's info() prints its description itself and returns None, so it is
# called directly instead of inside print() (which added a stray "None" line).
mimic_data.info()
print(f"\nAvailable MIMIC-III tables: {mimic_data.available_tables}")

In [None]:
# Summary statistics of the loaded dataset (PyHealth stat()).
mimic_data.stat()

In [None]:
# Keep only patients with at least two visits.
filtered_patients = {
    pid: pt
    for pid, pt in mimic_data.patients.items()
    if len(pt.visits) >= 2
}
print(f"The number of patients with >=2 visits is {len(filtered_patients)}.")

In [None]:
# Distribution of visit counts across the cohort.
# NOTE: in a top-to-bottom run this executes before the <=5-visit cap below, so the
# bar chart shows all patients with >=2 visits; re-run after the cap to plot the capped cohort.
visit_counts = {pid: len(pt.visits) for pid, pt in filtered_patients.items()}
frequency = Counter(visit_counts.values())
sorted_visits = dict(sorted(frequency.items()))
df_visits = pd.DataFrame(list(sorted_visits.items()), columns=['visits', 'frequency'])

In [None]:
# Smallest number of visits whose cumulative patient share reaches the threshold.
coverage_threshold = 0.95  # 95% threshold
total_patients = sum(frequency.values())
cumulative_frequency = 0
cutoff = None

for visits, count in sorted(frequency.items()):
    cumulative_frequency += count
    if cumulative_frequency / total_patients < coverage_threshold:
        continue
    cutoff = visits
    break

print(f"Suggested cutoff based on 95% cumulative frequency: {cutoff} visits")

In [None]:
# Cap the history at 5 visits, leaving patients with 2-5 visits.
# NOTE(review): 5 presumably comes from the 95% cutoff printed above — confirm,
# or use `cutoff` directly.
filtered_patients = {
    pid: pt for pid, pt in filtered_patients.items() if len(pt.visits) <= 5
}

In [None]:
# Bar chart of how many patients have each number of visits. df_visits holds the
# counts computed last (in a top-to-bottom run, those from before the <=5 cap).
fig_visits, ax_visits = plt.subplots(figsize=(10, 6))
ax_visits.bar(df_visits['visits'], df_visits['frequency'])

ax_visits.set_title('Frequency of patient visits')
ax_visits.set_xlabel('No. of visits')
ax_visits.set_ylabel('No. of patients')
ax_visits.tick_params(axis='x', labelrotation=45)

# Write each bar's count just above it.
for bar in ax_visits.patches:
    ax_visits.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.1,
        f'{int(bar.get_height())}',
        ha='center', va='bottom', color='black', rotation='horizontal',
    )
plt.show()

## Creating the patient dataset - codes
First, we will create some functions to label patients based on the presence or absence of an ICD-9 heart failure code in their medical history. 
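In the end, we will have a dataset where each row is one patient visit; its columns hold that visit's diagnosis, procedure and medication codes (as comma-separated strings), some demographic and admission fields, and the patient's binary HF label. For HF-positive patients, only the visits before their first HF diagnosis are kept.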

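Next, we look up the embedding of every code and average the embeddings per code class (diagnosis, procedure, medication) within each visit. The three class means are then averaged into a visit embedding, and the visit embeddings are averaged per patient to obtain the final patient representation. In the end, each row represents a patient's medical history, and the columns hold their mean embedding (dim=300).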

In [None]:
# dict structure: inspect one hard-coded patient / visit (the IDs must exist in filtered_patients)
patient = filtered_patients['10004']
visit = patient.visits['161106']
d = visit.get_event_list('DIAGNOSES_ICD') # list of diagnosis events of this visit
print(d[0]) # single event
print(d[0].code, type(d[0].code)) # diagnosis code - stored as str

In [None]:
def diagnosis_stats(codes, patients, diagnosis):
    """Print how many patients have at least one diagnosis code from `codes`.

    Args:
        codes: iterable of diagnosis code strings to look for (e.g. HF_ICD9_CODES).
        patients: dict of patient_id -> patient. Each patient exposes `patient_id`
            and `visits` (a dict whose values are visits with `available_tables`
            and `get_event_list(table)`).
        diagnosis: human-readable name used in the printed messages.

    Prints the number of patients in `patients`, how many of them have a matching
    DIAGNOSES_ICD event in any visit, and the percentage. Returns None.
    """
    code_set = set(codes)  # O(1) membership per event instead of a list scan
    # The denominator is the mapping that was passed in, not a notebook global.
    total_patient_count = len(patients)

    patients_with_diagnosis = set()
    for patient in patients.values():
        for visit in patient.visits.values():
            if 'DIAGNOSES_ICD' not in visit.available_tables:
                continue
            if any(event.code in code_set for event in visit.get_event_list('DIAGNOSES_ICD')):
                patients_with_diagnosis.add(patient.patient_id)
                break  # one matching visit is enough; skip this patient's remaining visits

    pt_diagnosis_count = len(patients_with_diagnosis)

    print(f"Total number of patients: {total_patient_count}")
    print(f"Number of patients with at least one {diagnosis} code: {pt_diagnosis_count}")

    if total_patient_count > 0:
        percentage_patients = (pt_diagnosis_count / total_patient_count) * 100
        print(f"Percentage of patients with at least one {diagnosis} code: {percentage_patients:.2f}%")
    else:
        print("no patients found")

# Share of the 2-5-visit cohort with at least one heart-failure diagnosis code.
diagnosis_stats(HF_ICD9_CODES, filtered_patients, "heart failure")

In [None]:
def filter_visits(patient, codes):
    """Split a patient's history at their first diagnosis with a code in `codes`.

    Args:
        patient: object whose `visits` is a dict of visit_id -> visit; each visit
            has `encounter_time` and `get_event_list(table)`.
        codes: collection of diagnosis codes that count as the diagnosis.

    Returns:
        (has_diagnosis, visits_to_include):
        - has_diagnosis: True if any DIAGNOSES_ICD event of any visit has a code in `codes`.
        - visits_to_include: a new dict. For positives, only the visits whose
          encounter_time is strictly earlier than the earliest matching visit (so it
          is empty when no visit precedes it). For negatives, every visit.
    """
    found = False
    earliest_hit = None

    for visit in patient.visits.values():
        visit_has_code = any(
            event.code in codes for event in visit.get_event_list('DIAGNOSES_ICD')
        )
        if not visit_has_code:
            continue
        found = True
        if earliest_hit is None or visit.encounter_time < earliest_hit:
            earliest_hit = visit.encounter_time

    if found and earliest_hit is not None:
        return found, {
            vid: v for vid, v in patient.visits.items() if v.encounter_time < earliest_hit
        }
    return found, dict(patient.visits.items())

In [None]:
# Among HF-positive patients, how many already carry an HF code in their
# chronologically first visit? Those patients have no pre-diagnosis history.
# (total_patients here overwrites the value computed for the visit cutoff.)
code_first_visit = 0
total_patients = 0
positives = 0
without_hf_first_visit = []

for patient_id, patient in filtered_patients.items():
    total_patients += 1
    has_diagnosis, visits_to_include = filter_visits(patient, HF_ICD9_CODES)
    if not has_diagnosis:
        continue
    positives += 1

    first_visit = min(patient.visits.values(), key=lambda v: v.encounter_time)
    if not first_visit:
        continue
    hf_in_first_visit = any(
        ev.code in HF_ICD9_CODES for ev in first_visit.get_event_list('DIAGNOSES_ICD')
    )
    if hf_in_first_visit:
        code_first_visit += 1
    else:
        without_hf_first_visit.append(patient_id)

if positives > 0:
    per_first_visit = (code_first_visit / positives) * 100
    print(f"No. positive patients: {positives}")
    print(f"Percentage of patients with HF who had a HF code in their first visit: {per_first_visit:.2f}%")
else:
    print("no patients found with HF diagnosis codes.")

if without_hf_first_visit:
    print(f"Patients with/without HF in first visit:{code_first_visit, len(without_hf_first_visit)}")
else:
    print("all patients have HF in their first visit.")

In [None]:
# HF-positive patient with an HF code in their first visit: filter_visits returns an
# empty dict, so this patient is dropped from the dataset.
diag, vis = filter_visits(filtered_patients['10160'], HF_ICD9_CODES)
diag, vis

In [None]:
# HF-positive patient whose HF code is not in the first visit: only the visits
# BEFORE the first HF visit are returned.
diag, vis = filter_visits(filtered_patients['10174'], HF_ICD9_CODES)
diag, vis

In [None]:
# Negative patient: every visit is returned.
diag, vis = filter_visits(filtered_patients['3868'], HF_ICD9_CODES)
diag, vis

In [None]:
def create_record(patient, visit, has_diagnosis):
    """Flatten one visit of one patient into a dict (one future DataFrame row).

    Args:
        patient: object with `patient_id`, `gender`, `ethnicity` and
            `birth_datetime` (a datetime, or a "%Y-%m-%d" string).
        visit: object with `visit_id`, `encounter_time` (a datetime, or a
            "%Y-%m-%d %H:%M:%S" string), `discharge_time`, `discharge_status`
            and `get_event_list(table)`.
        has_diagnosis: patient-level flag, stored as the integer `label`.

    Returns:
        dict with identifiers, demographics, age in whole years at admission
        (days // 365, so leap days are ignored), and the visit's diagnosis,
        procedure and prescription codes joined with ", ". Prescription events
        whose code is '0' are skipped.
    """
    def _to_datetime(value, fmt):
        if isinstance(value, datetime):
            return value
        return datetime.strptime(value, fmt)

    def _joined_codes(table, skip=()):
        return ', '.join(ev.code for ev in visit.get_event_list(table) if ev.code not in skip)

    admission_time = _to_datetime(visit.encounter_time, "%Y-%m-%d %H:%M:%S")
    birthdate = _to_datetime(patient.birth_datetime, "%Y-%m-%d")
    age_at_admission = (admission_time - birthdate).days // 365

    diagnosis_joined = _joined_codes('DIAGNOSES_ICD')
    procedure_joined = _joined_codes('PROCEDURES_ICD')
    prescription_joined = _joined_codes('PRESCRIPTIONS', skip=('0',))  # ignore NDC code '0'

    return {
        'patient_id': patient.patient_id,
        'label': int(has_diagnosis),
        'visit_id': visit.visit_id,
        'gender': patient.gender,
        'ethnicity': patient.ethnicity,
        'age': age_at_admission,
        'admission_time': admission_time,
        'discharge_time': visit.discharge_time,
        'discharge_status': visit.discharge_status,
        'diagnosis_codes': diagnosis_joined,
        'procedure_codes': procedure_joined,
        'prescription_codes': prescription_joined,
    }

In [None]:
def generate_patient_visits(filtered_patients, codes):
    """Build the visit-level dataset: one create_record() row per kept visit.

    Args:
        filtered_patients: dict of patient_id -> patient.
        codes: diagnosis codes defining the positive label (e.g. HF_ICD9_CODES).

    Returns:
        pandas DataFrame with one row per visit returned by filter_visits.
        Patients for whom filter_visits returns no visits contribute no rows.
    """
    rows = []
    for patient in filtered_patients.values():
        has_diagnosis, kept_visits = filter_visits(patient, codes)
        # e.g. an HF-positive patient whose earliest visit already has an HF code
        # has no pre-diagnosis visits, so kept_visits is empty and nothing is added.
        rows.extend(
            create_record(patient, visit, has_diagnosis) for visit in kept_visits.values()
        )
    return pd.DataFrame(rows)

In [None]:
# Visit-level table: one row per visit kept by filter_visits.
hf_patient_seqs = generate_patient_visits(filtered_patients, HF_ICD9_CODES)

In [None]:
# Number of distinct patients in the visit-level table, then the table itself.
print(hf_patient_seqs['patient_id'].nunique())
hf_patient_seqs

In [None]:
# Sanity check on three example patients (only positives with HF in their 2nd visit or later are kept):
#   10160 - HF code in the first visit     -> expected to be absent
#   10174 - HF code only in a later visit  -> present, label 1 on each kept visit
#   10059 - negative                       -> present, label 0
example_patient_ids = ['10160', '10174', '10059']
print(*(hf_patient_seqs['patient_id'].isin([pid]).any() for pid in example_patient_ids))
print(*(hf_patient_seqs.loc[hf_patient_seqs['patient_id'] == pid, 'label'].values
        for pid in example_patient_ids))  # 1 label per each visit

In [None]:
# Class balance at the visit level: each row of hf_patient_seqs is one visit and
# carries its patient's label, so these are visit counts, not patient counts.
positive_count = int((hf_patient_seqs['label'] == 1).sum())
negative_count = int((hf_patient_seqs['label'] == 0).sum())

pt_data = pd.DataFrame({
    'Label': ['Positive', 'Negative'],
    'Count': [positive_count, negative_count]
})

fig = px.bar(pt_data, x='Label', y='Count',
             width=700, height=600,
             labels={'Count': 'No. of visits', 'Label': 'Heart failure diagnosis'},
             title='Positive/Negative HF visits')
fig.show()

## Creating the patient dataset - embeddings
Now, we will start converting the list of codes into their embedding representation.

In [None]:
def format_embs(df, prefix):
    """Return the codes of one vocabulary, normalised to match the MIMIC tables.

    Selects the rows of `df` whose `prefix` column equals `prefix` and returns
    their `code` values (a Series named 'code', original index kept) with dots
    removed and the leading "<PREFIX>_" part stripped, e.g. "IDX_401.9" -> "4019".
    """
    def _normalise(raw_code):
        without_dots = raw_code.replace(".", "")  # mimic tables store codes without dots
        return without_dots.split('_', 1)[-1]      # drop the vocabulary prefix

    selected = df.loc[df['prefix'] == prefix, 'code']
    return selected.map(_normalise)

In [None]:
# Codes available in the embedding file, per vocabulary (normalised to MIMIC format) ...
diag_df = format_embs(emb_df, 'IDX')
med_df = format_embs(emb_df, 'N')
proc_df = format_embs(emb_df, 'IPR')

# ... and the distinct codes that actually occur in patients' visits. The columns hold
# ", "-joined strings; a visit with no codes of a class contributes ''.
def _codes_used(column):
    return {code for joined in hf_patient_seqs[column].str.split(', ') for code in joined}

diag_codes_patient = _codes_used('diagnosis_codes')
med_codes_patient = _codes_used('prescription_codes')
proc_codes_patient = _codes_used('procedure_codes')

In [None]:
def format_embs_all(df):
    """Return a copy of `df` with every `code` normalised to MIMIC format.

    Dots are removed (MIMIC tables store codes without them) and the leading
    "<PREFIX>_" part is stripped, e.g. "IDX_401.9" -> "4019". Other columns,
    including `prefix`, are unchanged.

    The input frame is NOT modified: the previous in-place assignment rewrote the
    caller's frame (a filtered slice of emb_df, which also triggers
    SettingWithCopyWarning) and made the notebook depend on cell execution order.
    """
    out = df.copy()
    out['code'] = out['code'].map(lambda code: code.replace(".", "").split('_', 1)[-1])
    return out

# removing prefixes from embeddings in final df: format_embs_all strips the "<PREFIX>_" part and
# dots from every code, while the 'prefix' column keeps the vocabulary of each row.
emb_df_wo_prefix = format_embs_all(emb_df)
# emb_df was already restricted to IDX / N / IPR earlier, so this filter is only a safeguard.
emb_df_wo_prefix = emb_df_wo_prefix[emb_df_wo_prefix['prefix'].isin(['IDX', 'N', 'IPR'])]
emb_df_wo_prefix

In [None]:
# Lookup from "<PREFIX>_<code>" (code without dots) to its embedding vector.
# If two rows produce the same key, the later row wins.
code_to_embedding = {
    f"{prefix}_{code}": vector
    for prefix, code, vector in zip(
        emb_df_wo_prefix['prefix'],
        emb_df_wo_prefix['code'],
        emb_df_wo_prefix['embedding'],
    )
}

In [None]:
# Mean embedding of each code class, per visit row.
def add_mean_emb(df, type, embedding_lookup=None):
    """Return a copy of `df` with a `<type>_embedding` column of mean code embeddings.

    Args:
        df: visit-level DataFrame with a `<type>_codes` column of comma-separated codes.
        type: 'diagnosis', 'prescription' or 'procedure'. Selects the codes column
            and the lookup key prefix ('IDX_', 'N_', 'IPR_'). (Name kept for
            backward compatibility although it shadows the builtin.)
        embedding_lookup: optional mapping "<PREFIX>_<code>" -> embedding vector.
            Defaults to the notebook-level `code_to_embedding` dict.

    Returns:
        A new DataFrame; the input is not modified. Each `<type>_embedding` cell is
        the element-wise mean (numpy array) of the embeddings of that row's codes
        found in the lookup, or None when none of them is found.

    Raises:
        KeyError: if `type` is not one of the three supported code classes.
    """
    key_prefix = {'diagnosis': 'IDX_',
                  'prescription': 'N_',
                  'procedure': 'IPR_'}[type]
    lookup = code_to_embedding if embedding_lookup is None else embedding_lookup

    # Object array so each cell can hold a whole vector; np.empty(..., dtype=object)
    # starts filled with None, which marks rows without any known code.
    means = np.empty(len(df), dtype=object)
    for pos, joined_codes in enumerate(df[type + '_codes']):
        keys = (key_prefix + code.strip() for code in str(joined_codes).split(','))
        vectors = [lookup[key] for key in keys if key in lookup]
        if vectors:
            means[pos] = np.mean(vectors, axis=0)

    out = df.copy()
    out[type + '_embedding'] = means
    return out

In [None]:
# Add one mean-embedding column per code class; add_mean_emb leaves None where none of a
# visit's codes of that class has an embedding.
hf_patient_seqs_emb = add_mean_emb(hf_patient_seqs, 'diagnosis')
hf_patient_seqs_emb = add_mean_emb(hf_patient_seqs_emb, 'prescription')
hf_patient_seqs_emb = add_mean_emb(hf_patient_seqs_emb, 'procedure')
# NOTE: dropna() without `subset` drops a visit row with a missing value in ANY column, not only
# the three embedding columns. Rows are visits, so a patient may lose only some of their visits.
hf_patient_seqs_emb = hf_patient_seqs_emb.dropna() # dropping visits that do not have all 3 kinds of embs (~200)
hf_patient_seqs_emb, hf_patient_seqs_emb.shape

In [None]:
# Drop visits with age > 98 (originally noted as "for scatterplot"; applied per visit row).
# NOTE(review): presumably removes implausible ages (e.g. MIMIC-III's shifted birth dates
# for very old patients) — confirm the intent and threshold.
# .copy() makes the filtered frame independent, so the next cells can add columns to it
# without SettingWithCopyWarning.
hf_patient_seqs_emb = hf_patient_seqs_emb[hf_patient_seqs_emb['age'] <= 98].copy()


In [None]:
# Distinct patients left after dropping incomplete visits and applying the age filter.
print(hf_patient_seqs_emb['patient_id'].nunique())

In [None]:
# To represent a single visit, take the mean of the per-class mean embeddings.
def add_mean_emb_visit(df):
    """Return a copy of `df` with a `visit_embedding` column.

    Each `visit_embedding` is the element-wise mean of the row's non-empty numpy
    arrays found in the columns whose name contains '_embedding' (e.g. the
    diagnosis, prescription and procedure embeddings), or None when the row has
    no such array. The input DataFrame is not modified; it is typically a
    filtered frame, and writing into it in place raised SettingWithCopyWarning.
    """
    out = df.copy()
    source_cols = [col for col in out.columns
                   if '_embedding' in col and col != 'visit_embedding']

    visit_means = np.empty(len(out), dtype=object)  # starts filled with None
    for pos, cells in enumerate(zip(*(out[col] for col in source_cols))):
        vectors = [cell for cell in cells if isinstance(cell, np.ndarray) and cell.size > 0]
        if vectors:
            visit_means[pos] = np.mean(vectors, axis=0)

    out['visit_embedding'] = visit_means
    return out

In [None]:
# Add the per-visit mean of the class embeddings as `visit_embedding`.
hf_patient_seqs_emb_mean = add_mean_emb_visit(hf_patient_seqs_emb)
hf_patient_seqs_emb_mean

In [None]:
print(hf_patient_seqs_emb_mean.isna().sum()) # missing values per column (expected: none)

In [None]:
# Preview the first rows of the visit-level table with embeddings.
hf_patient_seqs_emb_mean.head()

In [None]:
# Stack the per-visit vectors into a (n_visits, dim) matrix and reduce it with PCA,
# keeping enough components to explain 95% of the variance.
visit_embeddings = np.stack(hf_patient_seqs_emb_mean['visit_embedding'].to_list())
labels = hf_patient_seqs_emb_mean['label'].to_numpy()
pca = PCA(n_components=0.95)
pca_visit_embeddings = pca.fit_transform(visit_embeddings)

In [None]:
# (n_visits, n_components kept for 95% explained variance)
pca_visit_embeddings.shape

In [None]:
# First two principal components of the visit embeddings, coloured by HF label.
fig_visit_pca, ax_visit_pca = plt.subplots(figsize=(10, 8))
sns.scatterplot(
    x=pca_visit_embeddings[:, 0],
    y=pca_visit_embeddings[:, 1],
    hue=labels,
    legend='full',
    alpha=0.7,
    ax=ax_visit_pca,
)
ax_visit_pca.set_title('PCA of visit embeddings')
ax_visit_pca.set_xlabel('c1')
ax_visit_pca.set_ylabel('c2')
ax_visit_pca.legend(title='Label')
ax_visit_pca.grid(True)
plt.show()

In [None]:
# The final step: aggregate the visit embeddings into one row per patient.
def avg_patient_emb(df):
    """Collapse visit rows into one row per patient.

    Args:
        df: visit-level DataFrame with `patient_id`, `label` and `visit_embedding`
            (one vector per row).

    Returns:
        DataFrame with columns ['patient_id', 'label', 'patient_embedding'], one row
        per patient, ordered by patient_id (groupby's default sort):
        - label: maximum label over the patient's visits;
        - patient_embedding: element-wise mean of the patient's visit embeddings.
    """
    # Demographic aggregates (mode of gender/ethnicity, mean age, ...) could be added
    # to each record here. Note that no 'visit_length' column is produced earlier in
    # this notebook.
    records = []
    for pid, visits in df.groupby('patient_id'):
        visit_vectors = list(visits['visit_embedding'])
        if not visit_vectors:
            continue
        records.append({
            'patient_id': pid,
            'label': visits['label'].max(),
            'patient_embedding': np.mean(visit_vectors, axis=0),
        })

    return pd.DataFrame(records, columns=['patient_id', 'label', 'patient_embedding'])

In [None]:
# One row per patient: label and mean visit embedding.
patient_emb_df = avg_patient_emb(hf_patient_seqs_emb_mean)
patient_emb_df

In [None]:
print(patient_emb_df.isna().sum()) # missing values per column (expected: none)

In [None]:
# Class balance at the patient level (patient_emb_df has one row per patient).
patient_labels = patient_emb_df['label']
positive_count = int((patient_labels == 1).sum())
negative_count = int((patient_labels == 0).sum())

pt_data = pd.DataFrame({
    'Label': ['Positive', 'Negative'],
    'Count': [positive_count, negative_count],
})

fig = px.bar(
    pt_data, x='Label', y='Count',
    width=700, height=600,
    labels={'Count': 'No. of patients', 'Label': 'Heart failure diagnosis'},
    title='Incidence of heart failure in the patient cohort',
)
fig.update_xaxes(tickangle=45)
fig.update_traces(text=pt_data['Count'], textposition='outside')
fig.show()




In [None]:
""" # for demographics
patient_embs = np.vstack(
    patient_emb_df.apply(lambda row: np.concatenate((row['patient_embedding'], 
                                                 [row['visit_length'], 
                                                  row['gender'], 
                                                  row['ethnicity'], 
                                                  row['age']])), axis=1)
) """

In [None]:
# PCA of the patient embeddings (components kept for 95% explained variance).
patient_embs = np.vstack(patient_emb_df['patient_embedding'])
pca = PCA(n_components=0.95)
pca_patient_embs = pca.fit_transform(patient_embs)
print(pca_patient_embs.shape)

# Per-patient colour by label: 1 = HF-positive, 0 = negative.
code_mapping = {1: '#e8e337', 0: '#636EFA'}
colors = patient_emb_df['label'].map(code_mapping)

patient_points = go.Scatter(
    x=pca_patient_embs[:, 0],
    y=pca_patient_embs[:, 1],
    mode='markers',
    marker=dict(color=colors, size=10, opacity=0.8),
    showlegend=False,
)
fig = go.Figure(data=patient_points)

# The data trace is hidden from the legend, so add one empty marker-only trace per
# class to act as the legend entries.
for legend_color, legend_name in [('#e8e337', 'Positive (1)'), ('#636EFA', 'Negative (0)')]:
    fig.add_trace(go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(color=legend_color, size=10),
        name=legend_name,
    ))

fig.update_layout(
    title="PCA: Patient Embeddings",
    xaxis_title="Component 1",
    yaxis_title="Component 2",
    width=900,
    height=700,
)
fig.show()


In [None]:
""" patient_emb_df.to_csv(r"PATH/patient_emb_df.csv")

patient_emb_cols = pd.DataFrame(patient_embs, columns=[f'dim_{i}' for i in range(patient_embs.shape[1])]) 
patient_emb_cols.to_csv(r"PATH/patient_emb_cols.csv")

cosine_sim_matrix = cosine_similarity(patient_embs)
cosine_sim_df = pd.DataFrame(cosine_sim_matrix, index=range(len(patient_embs)), columns=range(len(patient_embs))) 
cosine_sim_df.to_csv(r"PATH/cosine_sim_df.csv") """