# Cluster the embeddings

In [None]:
import numpy as np
import pandas as pd

import torch
from transformers import pipeline, AutoModel, AutoConfig, AutoTokenizer

from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import cosine
import os

DMIS Disease Feature Extractor

In [None]:
# Build a feature-extraction pipeline from the
# `dmis-lab/biosyn-sapbert-bc5cdr-disease` checkpoint ('pt' framework).
feature_extraction = pipeline(
    'feature-extraction',
    model='dmis-lab/biosyn-sapbert-bc5cdr-disease',
    framework='pt',
)

Get the variable descriptions

In [None]:
# Variable-description table; later cells read its 'Variable Name' and
# 'Description' columns.
df_vars = pd.read_csv(
    "~/scratch/datasets/yale_new_haven/supplementary_info/variable_descriptions.csv"
)

In [None]:
# Partition the variable names listed in df_vars['Variable Name'] into
# thematic groups. Every group is a set so that groups can be unioned and used
# for membership tests.
disposition_var = {'disposition'}

demographic_vars = {
    'age', 'gender', 'ethnicity', 'race', 'lang',
    'religion', 'maritalstatus', 'employstatus', 'insurance_status',
}

# department name, ESI score, arrival info, and triage vital info
# The triage vitals are matched against the 'Variable Name' column, exactly
# like the cc_/meds_ groups below. Matching against df_vars.columns would test
# the table's own column headers rather than the variable names it lists, so
# the triage vitals would silently end up in historical_lab_vars instead.
triage_evaluation_vars = {
    'dep_name', 'esi', 'arrivalmode', 'arrivalmonth', 'arrivalday', 'arrivalhour_bin',
} | {col for col in df_vars['Variable Name'] if 'triage_vital' in col}

# chief complaint info (only top 200 were included; represents >90% of the complaints)
chief_complaint_vars = {col for col in df_vars['Variable Name'] if "cc_" in col}

# medication info
medication_vars = {col for col in df_vars['Variable Name'] if 'meds_' in col}

hospital_usage_stats_vars = {'previousdispo', 'n_edvisits', 'n_admissions', 'n_surgeries'}

# prior imaging and EKG counts
# chest x-ray, echocardiogram, electrocardiogram (EKG), other x-ray, other ultra-sound, head CT, other CT, MRI, and all other imaging
imaging_ekg_vars = {
    'cxr_count', 'echo_count', 'ekg_count', 'otherxr_count', 'otherus_count',
    'headct_count', 'otherct_count', 'mri_count', 'otherimg_count',
}

# historic vitals include: systolic blood pressure, diastolic blood pressure, pulse, respiratory rate, oxygen saturation, presence of oxygen device, and temperature
# Each vital is summarised by four statistics, giving names such as
# 'dbp_last' or 'temp_median' (7 vitals x 4 statistics = 28 names).
historical_vital_vars = {
    f"{vital}_{stat}"
    for vital in ('dbp', 'o2_device', 'pulse', 'resp', 'sbp', 'spo2', 'temp')
    for stat in ('last', 'max', 'median', 'min')
}

# Everything assigned to a group so far.
curr = set().union(
    disposition_var,
    demographic_vars,
    triage_evaluation_vars,
    chief_complaint_vars,
    medication_vars,
    hospital_usage_stats_vars,
    imaging_ekg_vars,
    historical_vital_vars,
)

# past medical history
# Identified by elimination: not yet grouped and no underscore in the name.
past_medical_hist_vars = {
    col for col in df_vars['Variable Name']
    if col not in curr and "_" not in col and 'previousdispo' not in col
}

curr = curr.union(past_medical_hist_vars)

# historical labs ordered by ED (only top 150 comprising of 94% of all orders)
# Whatever is still ungrouped at this point is treated as a historical lab.
historical_lab_vars = {col for col in df_vars['Variable Name'] if col not in curr}

In [None]:
def _mean_pool(raw_features):
    """Average an extractor output over every axis except the last one."""
    tokens = np.asarray(raw_features, dtype=float)
    # Flatten to (n_vectors, dim) first: squeezing instead would collapse a
    # single-vector output to 1-D and the mean would then return a scalar.
    return tokens.reshape(-1, tokens.shape[-1]).mean(axis=0)


def cluster(df_tmp, clustering_features, d_thres=0.1, method='complete', extractor=None):
    """Cluster variables by the embedding of their textual description.

    Rows of ``df_tmp`` whose 'Variable Name' is in ``clustering_features`` are
    embedded (one mean-pooled vector per 'Description'), linked hierarchically
    under the cosine metric, and cut into flat clusters at ``d_thres``.

    Args:
        df_tmp: DataFrame with 'Variable Name' and 'Description' columns.
            It is not modified.
        clustering_features: Collection of variable names to cluster.
        d_thres: Distance threshold at which the linkage tree is cut
            (``fcluster`` with ``criterion='distance'``). Numeric strings are
            still accepted and converted with ``float``.
        method: Linkage method forwarded to ``scipy.cluster.hierarchy.linkage``.
        extractor: Optional callable mapping a description string to an
            array-like of feature vectors. Defaults to the module-level
            ``feature_extraction`` pipeline.

    Returns:
        A new DataFrame with columns 'Variable Name', 'Description' and
        'Cluster'. When fewer than two rows are selected no linkage can be
        built, so every selected row gets label 1 (``fcluster`` labels start
        at 1).
    """
    if extractor is None:
        extractor = feature_extraction

    selected = df_tmp['Variable Name'].isin(clustering_features)
    # Copy so that adding the 'Cluster' column never writes into (or warns
    # about) a slice of the caller's frame.
    result = df_tmp.loc[selected, ['Variable Name', 'Description']].copy()

    if len(result) < 2:
        result['Cluster'] = np.ones(len(result), dtype=int)
        return result

    embeddings = np.array([_mean_pool(extractor(desc)) for desc in result['Description']])

    Z = linkage(embeddings, method=method, metric='cosine')

    # Labels come back in row order, so positional assignment lines up.
    result['Cluster'] = fcluster(Z, float(d_thres), criterion='distance')

    return result

In [None]:
# The variables to cluster: every chief-complaint variable, as a list.
cluster_features = [*chief_complaint_vars]

In [None]:
# Cluster the selected descriptions with a 0.2 distance threshold, then order
# the rows by cluster label and, within a cluster, by description.
c = cluster(df_vars, cluster_features, d_thres=0.2)
c = c.sort_values(by=['Cluster', 'Description'])

In [None]:
# Map each cluster label to the list of descriptions assigned to it,
# keeping the row order of `c`.
clusters = {}
for description, label in zip(c['Description'], c['Cluster']):
    clusters.setdefault(label, []).append(description)

In [None]:
# Print every cluster label followed by its tab-indented descriptions and a
# blank separator line.
for label, descriptions in clusters.items():
    print(label)
    for description in descriptions:
        print(f"\t{description}")
    print()

In [None]:
# Name/description rows of the variables being clustered. The following cells
# read and extend this frame as `df_cc`, so it is bound under that name, and
# it is copied because a 'Cluster' column is added to it further down.
df_cc = df_vars.loc[
    df_vars['Variable Name'].isin(cluster_features),
    ['Variable Name', 'Description'],
].copy()

In [None]:
# One vector per description: the squeezed pipeline output averaged over its
# first axis (presumably the token axis — TODO confirm).
cc_embeddings = np.array([
    np.squeeze(feature_extraction(description)).mean(axis=0)
    for description in df_cc['Description']
])

In [None]:
# Hierarchical linkage of the embeddings: cosine metric, complete linkage.
Z = linkage(
    cc_embeddings,
    metric='cosine',
    method='complete',
)

In [None]:
# Distance threshold handed to fcluster (criterion='distance') in the next cell.
d_thres = 0.15

In [None]:
# Cut the linkage tree into flat clusters at the chosen distance threshold.
clustering = fcluster(
    Z,
    t=d_thres,
    criterion='distance',
)

In [None]:
# Attach the flat cluster labels, aligned on df_cc's own index.
cluster_labels = pd.Series(data=clustering, index=df_cc.index)
df_cc['Cluster'] = cluster_labels

In [None]:
# Cluster label -> descriptions, with the labels inserted in ascending order
# and the descriptions appended in df_cc row order.
cluster_dict = {label: [] for label in sorted(df_cc['Cluster'].unique())}
for description, label in zip(df_cc['Description'], df_cc['Cluster']):
    cluster_dict[label].append(description)

In [None]:
# Print each cluster as "<label>) <list of descriptions>" followed by a blank line.
for label, descriptions in cluster_dict.items():
    print(f"{label}) {descriptions}")
    print()

In [None]:
# Test-set feature table of the second split (regression_nn features).
df = pd.read_csv(
    "~/scratch/datasets/yale_new_haven/training_test_sets/second_split/features/regression_nn/yale_new_haven_test_features.csv"
)

In [None]:
# Feature columns whose name contains "cc_" (a substring match, not a prefix match).
cc_vars = [
    column_name
    for column_name in df.columns
    if "cc_" in column_name
]

In [None]:
# (Re)load the variable-description table used to look up descriptions below.
df_vars = pd.read_csv(
    "~/scratch/datasets/yale_new_haven/supplementary_info/variable_descriptions.csv"
)

In [None]:
# Variable name -> description, restricted to the "cc_" feature columns.
# If a name occurs more than once, its last description wins.
var_desc_dict = {
    name: description
    for name, description in df_vars.loc[
        df_vars['Variable Name'].isin(cc_vars),
        ['Variable Name', 'Description'],
    ].itertuples(index=False, name=None)
}

In [None]:
# For every row of df, collect the descriptions of the "cc_" columns whose
# value equals 1. The frame is transposed so that the default column-wise
# apply visits one original row at a time.
# NOTE(review): a "cc_" column with no entry in var_desc_dict raises KeyError here.
problems = df[cc_vars].T.apply(lambda x: [var_desc_dict[i] for i in x[x == 1].index])

In [None]:
# some people have no problems... What to do about them...?
# zeros will kill all the weights I guess...

In [None]:
# Display the collected complaint descriptions.
problems

In [None]:
# Index labels of the entries that list more than one complaint.
kk = problems[problems.apply(len) > 1].index

In [None]:
# Display the entry behind the second label in kk.
problems[kk[1]]

In [None]:
# For the first 100 entries: embed every complaint description (squeezed
# pipeline output averaged over its first axis) and then average those
# per-complaint vectors into a single vector per entry.
ep = problems[:100].apply(
    lambda complaints: np.array(
        [np.squeeze(feature_extraction(complaint)).mean(axis=0) for complaint in complaints]
    ).mean(axis=0)
)

In [None]:
# Storing the embeddings of all of the complaints is going to take A LOT of memory:
# ~200,000 x 768 floats just for the complaints (non-zero...).
# Inspect the element dtype of one embedding (presumably to size that estimate).
np.squeeze(feature_extraction('HIV')).dtype

In [None]:
# NOTE(review): `l` is not defined anywhere in the visible part of this
# notebook — presumably a leftover scratch list of embeddings; confirm or remove.
np.array(l).mean(axis=0).shape