In [1]:
import os
import csv
import umap
import numpy as np
import pandas as pd
import matplotlib
from bokeh.plotting import figure, output_file, save, ColumnDataSource, show, save, output_notebook
from bokeh.models import CategoricalColorMapper, HoverTool
from bokeh.io import export_svgs



In [2]:
# Root data directory. Defaults to the original location but can be
# overridden with the EHR_DATA_PATH environment variable (``~`` is expanded
# either way) instead of editing the notebook.
DATA_PATH = os.path.expanduser(
    os.environ.get('EHR_DATA_PATH', '~/data1/ehr-stratification/data'))
# Experiment folder under DATA_PATH; the encoding and cluster-label paths
# below are relative to it.
data_folder = 'ehr-804371-test-2'
# Not read by the visualization cells in this notebook.
dm_file = 'patient-details.csv'
# Patient encodings: header row, then one row per patient (MRN first,
# followed by the vector components).
# Alternative (presumably training-set encodings): 'encodings/TRconvae-avg_vect.csv'
enc_file = 'encodings/convae-avg_vect_scaled.csv'
# Predicted cluster id per patient, one per line (SNOMED subsampling, it. 67).
# Alternative (CCS labels): 'encodings/outer-cl-convae-ccs-single.txt'
cl_lab_snomed = 'encodings/cl-subsampling/outer-cl-convae-snomed-it67.txt'

In [3]:
# Per-patient disease table (relative to DATA_PATH): a header row followed by
# MRN,disease rows. The "it67" suffix presumably has to match the iteration
# of the cluster-label file used with it (cl_lab_snomed).
ds_snomed = ('snomed_subsampling/'
             'patient-5000-disease-snomed-it67.csv')
# CCS alternative: ds_ccs = 'patient-5000-disease-ccs-single.csv'

In [4]:
# Display name of each predicted cluster id found in the cluster-label files.
# The number in parentheses is presumably the cluster's purity.
pred_dict = {0: 'ADHD (0.81)',
             1: 'PD (0.50)',
             2: 'PC (0.47)',
             3: 'CD (0.20)',
             4: 'AD (0.30)',
             5: 'T2D (0.61)',
             6: 'MM (0.41)',
             7: 'BC (0.68)'}
# Order in which outer_viz draws (and colours) the clusters. It is
# alphabetical, i.e. the same order disease_viz uses for disease classes.
# With both functions sharing one colour list, a cluster named after a
# disease therefore gets that disease's colour (uniform colours for the
# external validation). Equivalent cluster-id order: [4, 0, 7, 3, 6, 2, 1, 5].
cl_pur_order = sorted(pred_dict.values())

In [5]:
# Every CSS4 named colour (name -> hex string), the pool for fallback palettes.
col_dict = matplotlib.colors.CSS4_COLORS
# Names excluded from that pool by the plotting functions
# ([c for c in col_dict if c not in c_out]). They are mostly very pale /
# near-white shades, presumably too faint on a white background.
c_out = [
    'bisque', 'mintcream', 'cornsilk', 'lavenderblush', 'aliceblue',
    'antiquewhite', 'aqua', 'aquamarine', 'azure', 'beige',
    'powderblue', 'floralwhite', 'ghostwhite', 'lightcoral', 'lightcyan',
    'lightgoldenrodyellow', 'lightgray', 'lightgreen', 'lightgrey', 'lightpink',
    'lightsalmon', 'lightseagreen', 'lightskyblue', 'lightslategray', 'lightslategrey',
    'lightsteelblue', 'lightyellow', 'linen', 'palegoldenrod', 'palegreen',
    'paleturquoise', 'palevioletred', 'papayawhip', 'peachpuff', 'mistyrose',
    'lemonchiffon', 'lightblue', 'seashell', 'white', 'blanchedalmond',
    'oldlace', 'moccasin', 'snow', 'darkgray', 'ivory',
    'whitesmoke',
]

### Disease class visualization 

In [19]:
def disease_viz(ds_dic, enc_data, mrn=None):
    """Show a 2-D UMAP scatter plot of patient encodings coloured by disease.

    Parameters
    ----------
    ds_dic : str
        Path, relative to ``DATA_PATH``, of a CSV with a header row followed
        by ``MRN,disease`` rows.
    enc_data : str
        Path, relative to ``DATA_PATH/data_folder``, of a CSV with a header
        row followed by ``MRN,v1,...,vk`` rows (the patient encodings).
    mrn : list of str, optional
        Patients to plot, in this order. By default every patient of
        ``enc_data`` that also has a label in ``ds_dic`` is plotted, in file
        order.

    Prints the patient count per disease and the size of the projection.
    Then displays an interactive Bokeh figure in the notebook with one
    legend entry (click to hide) per disease.

    Raises
    ------
    ValueError
        If there are no patients to plot, or some requested ``mrn`` has no
        disease label or no encoding.
    """
    # MRN -> disease label.
    with open(os.path.join(DATA_PATH, ds_dic)) as f:
        rd = csv.reader(f)
        next(rd, None)  # skip header
        snomed_dct = {r[0]: r[1] for r in rd}

    # (MRN, encoding) of every labelled patient, in file order. This is
    # loaded unconditionally: the matrix used to be built only when ``mrn``
    # was None, so passing ``mrn`` raised a NameError on ``convae_mtx``.
    with open(os.path.join(DATA_PATH, data_folder, enc_data)) as f:
        rd = csv.reader(f)
        next(rd, None)  # skip header
        enc_rows = [(r[0], r[1:]) for r in rd if r[0] in snomed_dct]

    if mrn is None:
        mrn = [m for m, _ in enc_rows]
        vectors = [v for _, v in enc_rows]
    else:
        enc_dct = dict(enc_rows)
        missing = [m for m in mrn if m not in enc_dct]
        if missing:
            raise ValueError(
                '{0} requested MRNs have no disease label or no encoding, '
                'e.g. {1}'.format(len(missing), missing[:5]))
        mrn = list(mrn)  # never mutate / depend on the caller's container
        vectors = [enc_dct[m] for m in mrn]
    if not mrn:
        raise ValueError('no labelled patients with an encoding to plot')

    # CSV fields are strings; convert them explicitly before handing to UMAP.
    convae_mtx = np.asarray(vectors, dtype=float)

    # dis_pt[i] is the disease of patient mrn[i].
    dis_pt = [snomed_dct[m] for m in mrn]
    unique, counts = np.unique(dis_pt, return_counts=True)
    for a, b in zip(unique, counts):
        print(a, b)

    # Fixed seed so the 2-D layout is repeatable across runs.
    umap_tr = umap.UMAP(random_state=123, n_neighbors=20, min_dist=0.0)
    umap_mtx = umap_tr.fit_transform(convae_mtx)

    print(len(umap_mtx), len(dis_pt), umap_mtx.shape)

    # Row indices of each disease, collected in a single pass.
    groups = {}
    for i, dis in enumerate(dis_pt):
        groups.setdefault(dis, []).append(i)

    # Colours follow the sorted disease order. The first eight are fixed and
    # mirror the list outer_viz applies to cl_pur_order, so e.g. 'AD' and the
    # 'AD (0.30)' cluster share a colour. Further classes fall back to the
    # remaining CSS colours instead of being silently dropped by zip().
    fixed_colors = ['seagreen', 'tan', 'yellow', 'olivedrab',
                    'slateblue', 'hotpink', 'tomato', 'pink']
    palette = fixed_colors + [c for c in col_dict
                              if c not in c_out and c not in fixed_colors]

    TOOLTIPS = [('mrn', '@mrn'),
                ('ds_class', '@ds_class')]
    plotTools = 'box_zoom, wheel_zoom, pan,  crosshair, reset, save'
    output_notebook()
    p = figure(plot_width=1000, plot_height=800, tools=plotTools)
    # A single hover tool serves every renderer. Adding one per class (as
    # before) produced duplicate tooltips.
    p.add_tools(HoverTool(tooltips=TOOLTIPS))

    for n, dis in enumerate(sorted(groups)):
        sel = groups[dis]
        points = umap_mtx[sel]
        color = palette[n % len(palette)]
        source = ColumnDataSource(dict(
            x=points[:, 0].tolist(),
            y=points[:, 1].tolist(),
            mrn=[mrn[i] for i in sel],
            ds_class=[str(dis)] * len(sel)))
        p.circle('x', 'y', legend='ds_class',
                 source=source,
                 color=color,
                 muted_color=color,
                 muted_alpha=0.07)

    p.legend.click_policy = "hide"
    # Hide ticks, tick labels and grid lines.
    p.xaxis.major_tick_line_color = None
    p.xaxis.minor_tick_line_color = None
    p.yaxis.major_tick_line_color = None
    p.yaxis.minor_tick_line_color = None
    p.xaxis.major_label_text_color = None
    p.yaxis.major_label_text_color = None
    p.grid.grid_line_color = None
    # To write an HTML file instead of (or in addition to) showing inline:
#     output_file(os.path.join(DATA_PATH, data_folder, 
#                              'disease-viz-test-{0}.html'.format(data_folder.split('-')[3])))
#     save(p)
    show(p)
    # To export an SVG:
#     p.output_backend = "svg"
#     export_svgs(p, filename=os.path.join(DATA_PATH, data_folder, 
#                              'disease-viz-test-{0}.svg'.format(data_folder.split('-')[3])))

### Outer (external) validation visualization: UMAP coloured by predicted cluster

In [20]:
def outer_viz(ds_dic, enc_data, cl_lab, mrn=None):
    """Show a 2-D UMAP scatter plot of patient encodings coloured by
    predicted cluster (the outer / external validation view).

    Parameters
    ----------
    ds_dic : str
        Path, relative to ``DATA_PATH``, of a CSV with a header row followed
        by ``MRN,disease`` rows. Only patients listed here are plotted.
    enc_data : str
        Path, relative to ``DATA_PATH/data_folder``, of a CSV with a header
        row followed by ``MRN,v1,...,vk`` rows (the patient encodings).
    cl_lab : str
        Path, relative to ``DATA_PATH/data_folder``, of a text file with one
        integer cluster id per line. Line i is taken as the cluster of the
        i-th plotted patient and is named through ``pred_dict``.
    mrn : list of str, optional
        Patients to plot, in this order. By default every patient of
        ``enc_data`` that also appears in ``ds_dic`` is plotted, in file
        order.

    Prints the patient count per cluster id and the size of the projection.
    Then displays an interactive Bokeh figure with one legend entry (click to
    hide) per cluster, in ``cl_pur_order``. Clusters whose name is not in
    ``cl_pur_order``, and clusters without patients, are not drawn.

    Raises
    ------
    ValueError
        If there are no patients to plot, some requested ``mrn`` has no
        disease label or no encoding, or ``cl_lab`` does not hold exactly one
        line per plotted patient.
    """
    # MRN -> disease label (used here only to select the patients).
    with open(os.path.join(DATA_PATH, ds_dic)) as f:
        rd = csv.reader(f)
        next(rd, None)  # skip header
        snomed_dct = {r[0]: r[1] for r in rd}

    # (MRN, encoding) of every labelled patient, in file order. This is
    # loaded unconditionally: the matrix used to be built only when ``mrn``
    # was None, so passing ``mrn`` raised a NameError on ``convae_mtx``.
    with open(os.path.join(DATA_PATH, data_folder, enc_data)) as f:
        rd = csv.reader(f)
        next(rd, None)  # skip header
        enc_rows = [(r[0], r[1:]) for r in rd if r[0] in snomed_dct]

    if mrn is None:
        mrn = [m for m, _ in enc_rows]
        vectors = [v for _, v in enc_rows]
    else:
        enc_dct = dict(enc_rows)
        missing = [m for m in mrn if m not in enc_dct]
        if missing:
            raise ValueError(
                '{0} requested MRNs have no disease label or no encoding, '
                'e.g. {1}'.format(len(missing), missing[:5]))
        mrn = list(mrn)  # never mutate / depend on the caller's container
        vectors = [enc_dct[m] for m in mrn]
    if not mrn:
        raise ValueError('no labelled patients with an encoding to plot')

    # pred_cl[i] is the predicted cluster id (as text) of patient mrn[i].
    with open(os.path.join(DATA_PATH, data_folder, cl_lab)) as f:
        pred_cl = f.read().splitlines()

    unique, counts = np.unique(pred_cl, return_counts=True)
    for a, b in zip(unique, counts):
        print(a, b)

    # Labels are matched to patients by position, so the counts must agree.
    # Previously, extra labels were silently ignored (misaligned plot) and
    # missing ones surfaced as an IndexError after the UMAP fit.
    if len(pred_cl) != len(mrn):
        raise ValueError(
            '{0} has {1} cluster labels but {2} patients are plotted'.format(
                cl_lab, len(pred_cl), len(mrn)))

    # Row indices of each cluster name, collected in a single pass.
    groups = {}
    for i, c in enumerate(pred_cl):
        groups.setdefault(pred_dict[int(c)], []).append(i)

    # CSV fields are strings; convert them explicitly before handing to UMAP.
    convae_mtx = np.asarray(vectors, dtype=float)
    # Fixed seed so the 2-D layout is repeatable across runs.
    umap_tr = umap.UMAP(random_state=123, n_neighbors=20, min_dist=0.0)
    umap_mtx = umap_tr.fit_transform(convae_mtx)

    print(len(umap_mtx), len(pred_cl), umap_mtx.shape)

    # Colours are assigned by position in cl_pur_order (alphabetical), with
    # the same first eight colours disease_viz uses for its sorted diseases.
    # Extra entries fall back to the remaining CSS colours instead of being
    # dropped.
    fixed_colors = ['seagreen', 'tan', 'yellow', 'olivedrab',
                    'slateblue', 'hotpink', 'tomato', 'pink']
    palette = fixed_colors + [c for c in col_dict
                              if c not in c_out and c not in fixed_colors]

    TOOLTIPS = [('mrn', '@mrn'),
                ('ds_class', '@ds_class')]
    plotTools = 'box_zoom, wheel_zoom, pan,  crosshair, reset, save'
    output_notebook()
    p = figure(plot_width=1000, plot_height=800, tools=plotTools)
    # A single hover tool serves every renderer. Adding one per class (as
    # before) produced duplicate tooltips.
    p.add_tools(HoverTool(tooltips=TOOLTIPS))

    for n, cl in enumerate(cl_pur_order):
        sel = groups.get(str(cl), [])
        if not sel:
            # Empty cluster: np.array([])[:, 0] used to raise IndexError here.
            continue
        points = umap_mtx[sel]
        color = palette[n % len(palette)]
        source = ColumnDataSource(dict(
            x=points[:, 0].tolist(),
            y=points[:, 1].tolist(),
            mrn=[mrn[i] for i in sel],
            ds_class=[str(cl)] * len(sel)))
        p.circle('x', 'y', legend='ds_class',
                 source=source,
                 color=color,
                 muted_color=color,
                 muted_alpha=0.07)

    p.legend.click_policy = "hide"
    p.legend.location = 'bottom_right'
    # Hide ticks, tick labels and grid lines.
    p.xaxis.major_tick_line_color = None
    p.xaxis.minor_tick_line_color = None
    p.yaxis.major_tick_line_color = None
    p.yaxis.minor_tick_line_color = None
    p.xaxis.major_label_text_color = None
    p.yaxis.major_label_text_color = None
    p.grid.grid_line_color = None
    # To write an HTML file instead of (or in addition to) showing inline:
#    output_file(os.path.join(DATA_PATH, data_folder, 
#                             'outer-viz-test-{0}.html'.format(data_folder.split('-')[3])))
#    save(p)
    show(p)
    # To export an SVG:
#     p.output_backend = "svg"
#     export_svgs(p, filename=os.path.join(DATA_PATH, data_folder, 
#                              'outer-viz-test-{0}.svg'.format(data_folder.split('-')[3])))

In [21]:
# UMAP of the SNOMED-subsample encodings, coloured by disease label.
disease_viz(ds_dic=ds_snomed, enc_data=enc_file)

ADHD 2455
MM 1935
AD 2496
PD 2518
PC 2519
BC 2462
CD 2521
T2D 2503
19409 19409 (19409, 2)


In [22]:
# Same projection, coloured by predicted cluster (external validation).
outer_viz(ds_dic=ds_snomed, enc_data=enc_file, cl_lab=cl_lab_snomed)

7 694
4 4482
5 853
1 2082
3 5163
2 2197
6 2315
0 1623
19409 19409 (19409, 2)
