In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial import distance
from collections import Counter

from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

In [None]:
#modules to handle icd codes
import sys
!{sys.executable} -m pip install icd10-cm
!{sys.executable} -m pip install simple-icd-10
import icd10
import simple_icd_10 as icd

In [None]:
# Per-patient diagnosis-chapter columns read from the CSV, then binarised:
# a value greater than 0 becomes 1, anything else (0, negative, NaN) becomes 0.
chapter_cols = ['Ear', 'Blood/Immune', 'Circulatory', "Abnormal Labs", "Musculoskeletal", "Genitourinary", "Mental", "Eye", "Skin", "Nutritional", "Nervous", "Respiratory", "Digestive"]

patients = pd.read_csv('discharge_chapters_simple.csv', index_col='index', usecols=['index', 'id'] + chapter_cols)
patients[chapter_cols] = (patients[chapter_cols] > 0).astype('int64')

patients.head()

In [None]:
# Load the other dataframes that are mapped onto the clusters to see what the
# main features of each cluster are.
symptoms = pd.read_csv('symptoms_id.csv', index_col='index')

# Missing values in the combined table are replaced by 0.
general = pd.read_csv('combined_clean.csv', index_col='index')
general = general.fillna(0)

In [None]:
# Collapse the individual oxygen-support columns of `general` into two 0/1 flags:
# a patient is flagged when the sum of the relevant columns is positive.
# The boolean frame is cast with astype(int) rather than DataFrame.applymap,
# which is deprecated in recent pandas releases.
oxygen_therapies = pd.DataFrame({
    'OXY: noninvasive': (general['Oxygen therapy - face mask'] + general['Oxygen therapy - high flow'] + general['Noninvasive ventilation']) > 0,
    'OXY: invasive': (general['Oxygen therapy - intubation'] + general['Oxygen therapy - ventilator'] + general['Invasive ventilation']) > 0,
}).astype(int)

In [None]:
# One-hot encode length of stay (`general.los`, presumably in days - confirm) into
# three bands, one row per row of `general`:
#   14 <= los < 28 -> 'LOS: 2-4weeks'
#   los >= 28      -> 'LOS: 4weeks+'
#   anything else  -> 'LOS: 0-2weeks'
# Interval comparisons are used instead of membership in range(14, 28) so that a
# fractional stay such as 14.5 lands in the 2-4 week band rather than falling
# through to 0-2 weeks. Building the frame in one vectorised step also avoids
# writing through `.values`, which is read-only under pandas Copy-on-Write.
los = general.los
two_to_four_weeks = (los >= 14) & (los < 28)
four_weeks_plus = los >= 28

length_of_stay = pd.DataFrame({
    # Catch-all band: every row that is in neither of the other two bands.
    'LOS: 0-2weeks': (~(two_to_four_weeks | four_weeks_plus)).astype(int),
    'LOS: 2-4weeks': two_to_four_weeks.astype(int),
    'LOS: 4weeks+': four_weeks_plus.astype(int),
})
length_of_stay.head()

In [None]:
# Attach the derived oxygen-therapy and length-of-stay flags to `general`.
# NOTE: both statements append columns to an existing name, so running this
# cell twice adds the columns a second time.
general = pd.concat((general, oxygen_therapies, length_of_stay), axis=1)

# Copy the comorbidity columns from `general` alongside the symptom columns.
morbidity_cols = [
    'morbidity_Diabetes',
    'morbidity_COPD',
    'morbidity_Hypertension',
    'morbidity_Heart disease',
    'morbidity_Renal disease',
    'morbidity_Tumor',
    'morbidity_Metabolic disorders',
    'morbidity_Respiratory diseases',
]
symptoms = pd.concat((symptoms, general[morbidity_cols]), axis=1)

In [None]:
# Demographic indicator table, one row per row of `general`: the patient id, the
# `Male` flag, and three age-band columns that start at 0.
# The columns are created by ordinary assignment rather than by writing into
# `.values`, which returns a read-only array under pandas Copy-on-Write.
demographics = pd.DataFrame({'id': general.id, 'male': general.Male})
for col in ['age: 18-49', 'age: 50-64', 'age: 65+']:
    demographics[col] = 0
demographics.head()

In [None]:
# Flag each patient's age band from `general.age`.
# Half-open intervals replace membership in range(...) so that a fractional age
# such as 49.5 is placed in a band instead of being left unflagged. Ages below
# 18 still receive no band (all three flags stay 0).
# Each column is overwritten in full, so re-running this cell is safe.
age = general.age
age_bands = {
    'age: 18-49': (age >= 18) & (age < 50),
    'age: 50-64': (age >= 50) & (age < 65),
    'age: 65+': age >= 65,
}
for col, in_band in age_bands.items():
    # Positional assignment: `demographics` takes its rows from `general`
    # (its id column is assigned directly from `general.id`).
    demographics[col] = in_band.astype(int).to_numpy()

demographics.head()

In [None]:
def getCells(nrows, ncolumns):
    """Build the centre coordinates of a staggered (hexagon-style) grid of cells.

    Cells are numbered row by row starting at 1, so the cell in row ``j`` and
    column ``i`` (both zero-based) has index ``i + j*ncolumns + 1``. The row
    stride follows ``ncolumns`` rather than a fixed 17, so indices stay unique
    and contiguous for any grid width (results are unchanged for 17 columns).

    Even-numbered rows (j = 0, 2, ...) are shifted half a cell to the right of
    odd-numbered rows.

    As a side effect the centres are drawn with ``plt.scatter`` on the current
    matplotlib axes.

    Parameters:
        nrows: number of rows in the grid.
        ncolumns: number of cells per row.

    Returns:
        dict mapping cell index -> (x, y) centre coordinates.
    """
    # sqrt(3)/2: the vertical distance between rows when horizontally adjacent
    # cell centres are one unit apart.
    row_height = 0.8660254

    cells = {}
    for j in range(nrows):
        y = (j + 1) * row_height
        x_start = 1 if j % 2 else 1.5
        for i in range(ncolumns):
            cells[i + j * ncolumns + 1] = (i + x_start, y)

    xs = [x for x, _ in cells.values()]
    ys = [y for _, y in cells.values()]
    plt.scatter(xs, ys)

    return cells

In [None]:
def makeClusterDict(cluster_mappings, n_clusters):
    """Group patient ids by the cluster label assigned to them.

    ``cluster_mappings`` is walked in step with the module-level ``patients.id``
    column, so the two must be in the same order; iteration stops at the
    shorter of the two.

    Parameters:
        cluster_mappings: iterable of cluster labels, one per patient.
        n_clusters: accepted for interface compatibility; not used.

    Returns:
        dict mapping each label that occurs to the list of patient ids carrying
        it, with keys in order of first appearance.
    """
    grouped = {}
    for patient_id, label in zip(patients.id, cluster_mappings):
        grouped.setdefault(label, []).append(patient_id)
    return grouped

In [None]:
def mean_col(cluster_ids, cluster_num, col, df = symptoms):
    """Mean of ``df[col]`` over the patients that belong to one cluster.

    For a 0/1 column this is the proportion of the cluster's patients that have
    the feature; for any other numeric column it is the mean value.

    Parameters:
        cluster_ids: dict mapping cluster label -> list of patient ids.
        cluster_num: label of the cluster to summarise.
        col: name of the column of ``df`` to average.
        df: DataFrame with an ``id`` column and the column ``col``. The default
            is bound to ``symptoms`` when this function is defined.

    Returns:
        ``sum(values) / len(values)`` over the rows of ``df`` whose id is in the
        cluster, or 0 when no row matches.

    Raises:
        KeyError: if ``cluster_num`` is not a key of ``cluster_ids``.
    """
    # A set makes each membership test O(1); with the original list, every row
    # of df triggered a scan of the whole cluster.
    member_ids = set(cluster_ids[cluster_num])

    values = [value for id_num, value in zip(df.id, df[col]) if id_num in member_ids]

    if not values:
        return 0
    return sum(values) / len(values)

In [None]:
def rename_col(name):
    """Turn a ``morbidity_*`` column name into a short ``MOR: *`` display label.

    The ``morbidity_`` prefix is replaced by ``'MOR: '`` and a trailing
    'disease', 'diseases' or 'disorders' (plus the whitespace before it) is
    removed, e.g. ``'morbidity_Heart disease'`` -> ``'MOR: Heart'``.

    Only names that *start* with the prefix are rewritten; the prefix is removed
    by slicing from the front, so a name that merely contains ``morbidity_``
    somewhere else is returned unchanged rather than being mangled.

    Parameters:
        name: column name (str).

    Returns:
        The display label, or ``name`` itself when it has no morbidity prefix.
    """
    prefix = 'morbidity_'
    if not name.startswith(prefix):
        return name

    label = name[len(prefix):]
    # The three suffixes are mutually exclusive as endings, so order is irrelevant.
    for suffix in ('disease', 'diseases', 'disorders'):
        if label.endswith(suffix):
            label = label[:-len(suffix)].rstrip()
            break
    return 'MOR: ' + label

In [None]:
def _prevalence_table(cluster_ids, cluster_names, columns, df):
    """Return a float DataFrame (rows: cluster_names, columns: columns) of per-cluster means from mean_col."""
    table = pd.DataFrame(index = cluster_names, columns = columns)
    for name in cluster_names:
        for col in columns:
            table.loc[name, col] = mean_col(cluster_ids, name, col, df)
    return table.astype(float)


def getClusterInfo(cluster_mappings, n_clusters, cluster_names = [], title = '', return_cluster_ids = False, return_prevalences = False):
    """Print cluster sizes and plot per-cluster feature prevalences as three stacked heatmaps.

    The three panels show, for every cluster, the mean of selected columns of
    the module-level frames ``patients`` (diagnosis-chapter flags), ``general``
    (severity flags) and ``demographics`` (sex and age-band flags).

    Parameters:
        cluster_mappings: cluster label per patient, in the same order as
            ``patients.id`` (it is grouped by ``makeClusterDict``).
        n_clusters: forwarded to ``makeClusterDict``; not otherwise used here.
        cluster_names: labels of the clusters to report, in display order. When
            empty, every label found in ``cluster_mappings`` is used, in order
            of first appearance. Only these clusters are summarised and plotted.
        title: figure title; a default title is used when empty.
        return_cluster_ids: include the label -> patient-id dict in the result.
        return_prevalences: include the three prevalence tables in the result.

    Returns:
        - ``(cluster_ids, demographic_prevalence, code_prevalence, general_prevalence)``
          when both flags are set,
        - ``cluster_ids`` when only ``return_cluster_ids`` is set,
        - ``(demographic_prevalence, code_prevalence, general_prevalence)`` when
          only ``return_prevalences`` is set,
        - ``None`` otherwise.
        ``general_prevalence`` keeps the mean ``los`` column even though it is
        left out of the heatmap.

    Raises:
        KeyError: if a name in ``cluster_names`` has no patients in
            ``cluster_mappings``.
    """
    # First make a dictionary of the patients assigned to each cluster.
    cluster_ids = makeClusterDict(cluster_mappings, n_clusters)

    if len(cluster_names) == 0:
        cluster_names = list(cluster_ids.keys())

    # Then print a summary of the clusters.
    count = 0
    for name in cluster_names:
        size = len(cluster_ids[name])
        print('There are ' + str(size) + ' patients in cluster ' + str(name) + '.')
        count += size
    print('There are ' + str(count) + ' patients altogether.')

    # The tables are filled for exactly the clusters in cluster_names, so their
    # rows always match the requested labels and order.
    code_prevalence = _prevalence_table(
        cluster_ids, cluster_names,
        ['Blood/Immune', 'Circulatory', "Abnormal Labs", "Musculoskeletal", "Genitourinary", "Nutritional", "Nervous", "Respiratory", "Digestive"],
        patients)
    general_prevalence = _prevalence_table(
        cluster_ids, cluster_names,
        ['ICU', 'death', 'OXY: noninvasive', 'OXY: invasive', 'ECMO', 'los', 'LOS: 0-2weeks', 'LOS: 2-4weeks', 'LOS: 4weeks+'],
        general)
    demographic_prevalence = _prevalence_table(
        cluster_ids, cluster_names,
        ['male', 'age: 18-49', 'age: 50-64', 'age: 65+'],
        demographics)

    # Note: this changes seaborn's global style for later plots as well.
    sns.set(font_scale=1.1)
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(6, 15), gridspec_kw={'height_ratios':[1, 1, 1]})
    fig.subplots_adjust(hspace=0.35)
    fig.suptitle(title if title else "Prevalence of binary features in each cluster's patients")

    h1 = sns.heatmap(code_prevalence.T, vmin = 0, vmax = 1, cmap = "magma_r", cbar=False, ax = ax1)
    ax1.set_ylabel('Location of deterioration indicators')

    # All panels share the 0-1 colour scale, so one colorbar above the top
    # panel serves the whole figure: split the heatmap axes to make room for it.
    ax_divider = make_axes_locatable(h1)
    cax = ax_divider.append_axes('top', size = '5%', pad = '2%')
    # heatmap() returns an Axes; the colorbar needs the mappable it drew,
    # which is taken from the Axes' first child.
    fig.colorbar(h1.get_children()[0], cax = cax, orientation = 'horizontal')
    cax.xaxis.set_ticks_position('top')

    # 'los' is a mean length of stay rather than a 0-1 prevalence, so it is
    # left out of the heatmap.
    sns.heatmap(general_prevalence.drop(columns = ['los']).T, vmin = 0, vmax = 1, cmap = "magma_r", ax = ax2, cbar = False)
    ax2.set_ylabel('Severity of deterioration indicators')

    sns.heatmap(demographic_prevalence.T, vmin = 0, vmax = 1, cmap = "magma_r", ax = ax3, cbar = False)
    ax3.set_xlabel('Cluster')
    ax3.set_ylabel('Demographic information')

    plt.show()

    if return_cluster_ids and return_prevalences:
        return cluster_ids, demographic_prevalence, code_prevalence, general_prevalence
    elif return_cluster_ids and not return_prevalences:
        return cluster_ids
    elif return_prevalences:
        return demographic_prevalence, code_prevalence, general_prevalence

Get discharge codes:

In [None]:
# Load the discharge diagnoses, discarding rows that have no diagnosis code and
# rows that are exact duplicates of an earlier row.
discharge = (
    pd.read_csv('../data/discharge_diags.csv')
    .dropna(axis=0, subset=['diag_code'])
    .drop_duplicates()
)

discharge.head()

In [None]:
def trim_code(code):
    '''Shorten a raw diagnosis code to a prefix recognised by the ICD-10 lookups.

    A prefix counts as recognised when either ``icd.is_valid_item`` or
    ``icd10.exists`` accepts it. The choice is:

    * first 6 characters, if both the 5- and the 6-character prefixes are recognised;
    * first 5 characters, if the 5-character prefix is recognised but the 6 is not;
    * first 3 characters, if only the 3-character prefix is recognised;
    * first 5 characters as a fallback when nothing is recognised.
    '''
    def recognised(candidate):
        return icd.is_valid_item(candidate) or icd10.exists(candidate)

    five = code[:5]
    if recognised(five):
        six = code[:6]
        return six if recognised(six) else five

    three = code[:3]
    if recognised(three):
        return three

    # Nothing matched: keep the 5-character prefix anyway.
    return five

Make a dictionary that maps every raw diagnosis code to its trimmed code (some raw entries contain two codes!)

In [None]:
# Map every distinct raw diagnosis code to its trimmed form (see trim_code).
# A raw entry of the form 'A+B*' holds two codes: it is mapped to the list
# [trimmed A, trimmed B], and A and B are also registered on their own.
# Existing keys are never overwritten.
code_mappings = {}
for code in discharge.diag_code.unique():
    if '+' in code:
        code1, code2 = code.split('+')
        # The second code normally ends with a '*' marker. Remove it only when
        # it is actually there, so a code without the marker keeps its last
        # character.
        if code2.endswith('*'):
            code2 = code2[:-1]
        short_code1 = trim_code(code1)
        short_code2 = trim_code(code2)
        code_mappings.setdefault(code1, short_code1)
        code_mappings.setdefault(code2, short_code2)
        code_mappings.setdefault(code, [short_code1, short_code2])
    else:
        code_mappings.setdefault(code, trim_code(code))

In [None]:
def getClusterCodes(cluster_ids, cluster_num):
    """Collect the trimmed diagnosis codes recorded for the patients of one cluster.

    Every row of the module-level ``discharge`` frame whose id belongs to the
    cluster contributes the value stored for its ``diag_code`` in
    ``code_mappings``: a single trimmed code, or both codes of a combined entry.

    Parameters:
        cluster_ids: dict mapping cluster label -> list of patient ids.
        cluster_num: label of the cluster to collect codes for.

    Returns:
        list of trimmed code strings, one per recorded code (repeats are kept).

    Raises:
        KeyError: if ``cluster_num`` is not in ``cluster_ids`` or a diagnosis
            code is missing from ``code_mappings``.
    """
    # Set lookup keeps the scan over `discharge` linear in its length.
    member_ids = set(cluster_ids[cluster_num])

    codes = []
    for id_num, code in zip(discharge.id, discharge.diag_code):
        if id_num not in member_ids:
            continue
        mapped = code_mappings[code]
        # Combined entries are stored as a list. Test the type rather than
        # len(...) == 2: a two-character string would pass the length test and
        # be added character by character.
        if isinstance(mapped, list):
            codes.extend(mapped)
        else:
            codes.append(mapped)

    return codes

In [None]:
def summariseClusterCodes(cluster_ids, cluster_names):
    """Print, for each cluster, its five most common diagnosis codes other than COVID-19.

    The COVID-19 code 'U07.1' is left out of the ranking and reported on a line
    of its own after the top five. Each code is printed with the description
    from ``icd`` if that library knows it, otherwise from ``icd10``, otherwise
    with an empty description.

    Parameters:
        cluster_ids: dict mapping cluster label -> list of patient ids.
        cluster_names: labels of the clusters to summarise, in print order.

    Returns:
        dict mapping each cluster label to a ``Counter`` of all its codes
        (COVID-19 included).
    """
    covid_code = 'U07.1'
    ClusterCodes = {}

    for name in cluster_names:
        tally = Counter(getClusterCodes(cluster_ids, name))
        ClusterCodes[name] = tally

        # Rank every code, drop COVID-19, keep the first five.
        ranked = [(code, count) for code, count in tally.most_common() if code != covid_code]
        top5 = ranked[:5]

        print('CLUSTER ' + str(name) + ': ' + str(len(cluster_ids[name])) + ' patients')
        for code, count in top5:
            if icd.is_valid_item(code):
                description = icd.get_description(code)
            elif icd10.exists(code):
                description = icd10.find(code).description
            else:
                description = ''
            print(str(count) + ' counts of ' + code + ' - ' + description)
        print(str(tally[covid_code]) + ' counts of ' + 'U07.1 - COVID-19, virus identified')

    return ClusterCodes

# Analyse Clusters

In [None]:
# Cluster assignments, one column per clustering run; the 'Unnamed: 0' column
# in the CSV is discarded.
mappings = pd.read_csv('kmodes_clusters.csv').drop(columns='Unnamed: 0')
mappings.head()

In [None]:
# Summarise and plot the six-cluster 'disch_los_6' assignment.
getClusterInfo(
    mappings.disch_los_6,
    6,
    cluster_names=list(range(6)),
    title='Baseline K-Modes',
)
# Alternative clusterings (uncomment one to inspect it instead):
#getClusterInfo(mappings.disch_4_sub_clusters, 9, cluster_names = ['0a', '0b', '0c', '1a', '1b', '2a', '2b', '3a', '3b'], title = 'Layered Axes K-Modes')
#getClusterInfo(mappings.clusters_8, 10, cluster_names = list(range(-1, 9)), title = 'Prognosis Space DBSCAN')

In [None]:
#clusterIds = makeClusterDict(mappings.clusters_0_4, 9)
#summariseClusterCodes(clusterIds, list(range(9)))