In [None]:
from bokeh.plotting import figure, show
from bokeh.models import (RangeTool, ColumnDataSource, DataTable,
                          DateFormatter, TableColumn, HoverTool)
from bokeh.layouts import column
from bokeh.models.layouts import TabPanel, Tabs
import pandas as pd
from bokeh.io import output_notebook
from bokeh.palettes import all_palettes
import numpy as np
from bokeh.transform import jitter
import random
import pickle


In [None]:
# Event categories. This order drives loading, plotting, legends and color assignment.
data_types = [
    'Medication',
    'Procedure',
    'Labs',
    'Diagnoses',
]
# One Viridis swatch per category; the palette is looked up by category count.
colors = all_palettes['Viridis'][len(data_types)]

In [None]:
def process_data(data_dir='Data'):
    """Load the encounter table and all per-category event tables.

    Reads ``{data_dir}/encounters.pkl`` plus one ``{data_dir}/{category}.pkl``
    for every category in ``data_types``. It flattens the event tables into a
    single long-format frame.

    Parameters
    ----------
    data_dir : str, default 'Data'
        Directory holding the pickles.

    Returns
    -------
    events : pandas.DataFrame
        One row per event with columns ``category``, ``date`` (normalized to
        midnight), ``encounter_id`` (int), ``event_id``, ``event_name``,
        ``ICD9`` and ``ICD10``. ICD codes are only filled for Diagnoses.
        Rows without an encounter id are dropped.
    encounter_df : pandas.DataFrame
        The encounter table with columns renamed to ``encounter_id``,
        ``start_date`` and ``end_date``. Both dates are tz-naive datetimes, and
        a missing ``end_date`` defaults to ``start_date`` + 23 hours.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code embedded in the file. Only
    point ``data_dir`` at trusted data.
    """
    # Source-table schema per category:
    # (date column, encounter-id column, event-id column, event-name column).
    # It is keyed by category so it cannot drift out of alignment with data_types.
    sources = {
        'Medication': ('order_time_jittered', 'pat_enc_csn_id_coded', 'medication_id', 'med_description'),
        'Procedure': ('proc_start_time_jittered', 'pat_enc_csn_id_coded', 'cpt_code', 'description'),
        'Labs': ('order_time', 'pat_enc_csn_id_coded', 'proc_code', 'group_lab_name'),
        'Diagnoses': ('start_date', 'pat_enc_csn_id_jittered', 'dx_id', 'dx_name'),
    }

    with open(f'{data_dir}/encounters.pkl', 'rb') as f:
        encounter_df = pickle.load(f).rename(columns={'pat_enc_csn_id_coded': 'encounter_id',
                                                      'appt_when_jittered': 'start_date',
                                                      'hosp_dischrg_time_jittered': 'end_date'})

    # Strip any timezone from both ends so the fillna below never mixes
    # tz-aware and tz-naive values (tz_localize(None) is a no-op on naive data).
    encounter_df['start_date'] = pd.to_datetime(encounter_df['start_date']).dt.tz_localize(None)
    encounter_df['end_date'] = pd.to_datetime(encounter_df['end_date']).dt.tz_localize(None)
    # Encounters without an end time default to a 23-hour span.
    encounter_df['end_date'] = encounter_df['end_date'].fillna(
        encounter_df['start_date'] + np.timedelta64(23, 'h'))

    columns = {name: [] for name in
               ('category', 'date', 'encounter_id', 'event_id', 'event_name', 'ICD9', 'ICD10')}

    for category in data_types:
        date_col, enc_col, id_col, name_col = sources[category]
        with open(f'{data_dir}/{category}.pkl', 'rb') as f:
            raw = pickle.load(f)
        n_rows = len(raw)

        columns['category'] += [category] * n_rows
        # Keep only the calendar day of each event.
        columns['date'] += pd.to_datetime(raw[date_col]).dt.normalize().to_list()
        columns['encounter_id'] += raw[enc_col].to_list()
        columns['event_id'] += raw[id_col].to_list()
        columns['event_name'] += raw[name_col].to_list()

        if category == 'Diagnoses':
            # Only diagnoses carry ICD codes; NaN -> None so the cells stay empty.
            columns['ICD9'] += raw['icd9'].replace({np.nan: None}).tolist()
            columns['ICD10'] += raw['icd10'].replace({np.nan: None}).tolist()
        else:
            columns['ICD9'] += [None] * n_rows
            columns['ICD10'] += [None] * n_rows

    events = pd.DataFrame(columns)
    # Drop events with no encounter id, then store the remaining ids as ints.
    events = events.dropna(subset=['encounter_id']).astype({'encounter_id': int})

    return events, encounter_df

In [None]:
# Various helper functions
def get_jittered_times(df, date):
    """Add per-category start/end columns offset from the ``date`` column.

    For each category in ``data_types`` this writes
    ``jittered_start_<category>`` and ``jittered_end_<category>`` into ``df``
    (in place; returns None). Each is ``df[date]`` shifted by a fixed number of
    hours, so the categories' bars are offset from each other around the same
    anchor time.
    """
    # (start, end) offsets in hours relative to the anchor time.
    hour_offsets = {
        'Medication': (-7, -1),
        'Procedure': (-3, 3),
        'Labs': (1, 7),
        'Diagnoses': (5, 11),
    }
    anchor = df[date]
    for event_type in data_types:
        start_hours, end_hours = hour_offsets[event_type]
        df[f"jittered_start_{event_type}"] = anchor + np.timedelta64(start_hours, 'h')
        df[f"jittered_end_{event_type}"] = anchor + np.timedelta64(end_hours, 'h')

def get_limits(df):
    """Find the single largest per-category count in ``df``.

    Scans the column named after each category in ``data_types`` and returns
    ``(max_val, row_id)``, where ``row_id`` is the index *label* (not the
    position) of that maximum. On ties, the earliest category and the earliest
    row win. If no count exceeds 0 the result is ``(0, -1)``.
    """
    best_val, best_label = 0, -1
    for category in data_types:
        counts = df[category]
        label = counts.idxmax()
        peak = counts.loc[label]
        if peak > best_val:
            best_val, best_label = peak, label
    return best_val, best_label

def jitter_func(x, max_minutes=480, rng=None):
    """Return ``x`` shifted by a random whole number of minutes.

    The offset is drawn uniformly from ``[-max_minutes, max_minutes]``
    (inclusive). With the default, that is up to 8 hours either side.

    Parameters
    ----------
    x : datetime-like
        Any value supporting ``+ np.timedelta64`` (e.g. a pandas Timestamp).
    max_minutes : int, default 480
        Half-width of the jitter window, in minutes.
    rng : random.Random, optional
        Source of randomness. Pass a seeded ``random.Random`` for reproducible
        output. Defaults to the module-level ``random`` functions.
    """
    source = random if rng is None else rng
    delta = source.randint(-max_minutes, max_minutes)
    return x + np.timedelta64(delta, 'm')

def create_selection_tool(reference_graph, x_range, dates, grouped_df, bar_width=20):
    """Build an overview strip whose RangeTool drives ``reference_graph``'s view.

    One bar series per category in ``data_types`` is drawn at ``dates``. A
    RangeTool attached to the strip moves and resizes ``reference_graph``'s
    x_range and y_range.

    Parameters
    ----------
    reference_graph : bokeh figure
        Plot whose x_range/y_range the RangeTool controls. Its y_range is also
        shared by the strip.
    x_range : tuple or None
        Initial visible x extent of the strip itself. ``None`` leaves the
        strip's default auto-ranging in place.
    dates : array-like of datetimes
        Bar x positions.
    grouped_df : DataFrame
        One count column per category in ``data_types``, aligned row-for-row
        with ``dates``.
    bar_width : number, default 20
        Bar width in data units. On a datetime axis Bokeh measures widths in
        milliseconds, so the default is very narrow at day scale.

    Returns
    -------
    The new selection figure.
    """
    figure_kwargs = dict(
        title="Drag the middle and edges of the selection box to change the range above",
        height=100, width=2000,
        tools='xpan',
        y_range=reference_graph.y_range,
        x_axis_type='datetime',
        y_axis_type=None,
        background_fill_color="#efefef",
    )
    # Only pin the strip's own x extent when a window was requested.
    if x_range is not None:
        figure_kwargs['x_range'] = x_range
    selection_tool = figure(**figure_kwargs)

    range_tool = RangeTool(x_range=reference_graph.x_range, y_range=reference_graph.y_range)
    range_tool.overlay.fill_color = "navy"
    range_tool.overlay.fill_alpha = 0.2

    # Color each category with its palette entry. The translucent fill keeps
    # overlapping bars of different categories visible.
    for event, c in zip(data_types, colors):
        selection_tool.vbar(x=dates, top=grouped_df[event], width=bar_width,
                            color=c, fill_alpha=0.4)

    selection_tool.ygrid.grid_line_color = None
    selection_tool.add_tools(range_tool)
    selection_tool.toolbar.active_multi = 'auto'

    return selection_tool

In [None]:

# Create Encounter-Binned Tab
def plot_encounter_tab(df, encounter_df, tab_list):
    """Append an "Encounter Binned" TabPanel to ``tab_list``.

    Every encounter becomes one horizontal band spanning start_date..end_date,
    stacked in start-date order. Hovering shows its per-category event counts.
    Two RangeTool strips drive the main plot's range: one over all encounters
    and one zoomed on the encounters around the busiest one.

    Parameters
    ----------
    df : pandas.DataFrame
        Event table with ``encounter_id`` and ``category`` columns. Not modified.
    encounter_df : pandas.DataFrame
        Encounter table with ``encounter_id``, ``start_date`` and ``end_date``.
        Not modified.
    tab_list : list
        The new TabPanel is appended in place.
    """
    # Events per (encounter, category): one zero-filled count column per category.
    grouped = df.groupby(['encounter_id', 'category']).size().unstack(fill_value=0)

    # Attach encounter start/end times. Encounters with no start date cannot be placed.
    merged = grouped.join(encounter_df.set_index('encounter_id'),
                          lsuffix='_l', rsuffix='_r').reset_index()
    merged['total_events'] = merged[data_types].sum(axis=1)
    merged = merged.dropna(subset=['start_date'])
    # Renumber after sorting so index labels equal row positions. get_limits()
    # returns an index label, and it is used below as a position into `dates`.
    merged = merged.sort_values(by=['start_date']).reset_index(drop=True)

    # Each encounter occupies its own 2-unit band on the y axis.
    merged['y_start'] = list(range(0, 2 * len(merged), 2))
    merged['y_end'] = merged['y_start'] + 2
    source = ColumnDataSource(merged)

    _, row_id = get_limits(merged)
    # Calendar day of each encounter start, in the same (sorted) row order as merged.
    dates = np.array(merged['start_date'].dt.date, dtype=np.datetime64)

    encounter_grouped = figure(
        x_axis_type='datetime',
        x_axis_location='above',
        width=1500,
        height=300,
        tools=['pan', 'xpan', 'box_zoom', 'wheel_zoom', 'xbox_select', 'save', 'reset'],
        active_drag='xbox_select'
    )

    hover = HoverTool()
    hover.tooltips = [
        ('Encounter ID', '@encounter_id'),
        ('Encounter Start', '@start_date{%F %T}'),
        ('Encounter End', '@end_date{%F %T}'),
        ('Total', '@total_events'),
        ('Medications', '@Medication'),
        ('Procedures', '@Procedure'),
        ('Labs', '@Labs'),
        ('Diagnoses', '@Diagnoses')
    ]
    hover.formatters = {
        '@start_date': 'datetime', '@end_date': 'datetime'
    }
    encounter_grouped.add_tools(hover)

    # No legend_label here: this plot has a single unlabelled series, so it has no legend to style.
    encounter_grouped.quad(bottom='y_start', top='y_end', left='start_date', right='end_date',
                           source=source, line_color='black', fill_alpha=0.4)
    encounter_grouped.yaxis.axis_label = 'Encounter Index (not ID)'

    # Selection strip 1: move the selection area over the full timescale.
    # Counts come from `merged` so they line up row-for-row with `dates`.
    encounter_scatter_select_all = create_selection_tool(
        reference_graph=encounter_grouped, x_range=None,
        dates=dates, grouped_df=merged)

    # Selection strip 2: a +/-10 encounter window around the busiest
    # encounter, clamped to the available rows.
    last = len(dates) - 1
    narrow_range = (dates[max(row_id - 10, 0)], dates[min(row_id + 10, last)])
    encounter_scatter_select_narrow = create_selection_tool(
        reference_graph=encounter_grouped, x_range=narrow_range,
        dates=dates, grouped_df=merged)

    # Stack the main plot and both selection strips in one tab.
    encounter_tab = TabPanel(
        child=column(encounter_grouped,
                     encounter_scatter_select_all,
                     encounter_scatter_select_narrow),
        title="Encounter Binned")
    tab_list += [encounter_tab]

In [None]:
# Create Time-Binned Tab
def plot_time_tab(df, tab_list):
    """Append a "Time Binned" TabPanel to ``tab_list``.

    The tab stacks five views:
    - a daily bar chart of event counts per category;
    - two RangeTool strips (all days, and a window around the busiest day);
    - a per-event scatter plot sharing the bar chart's x range;
    - a data table of all events.

    Parameters
    ----------
    df : pandas.DataFrame
        Event table with ``date``, ``category``, ``encounter_id``,
        ``event_id``, ``event_name``, ``ICD9`` and ``ICD10`` columns.
        Not modified.
    tab_list : list
        The new TabPanel is appended in place.
    """
    N = len(df)

    # Count events per (day, category): one zero-filled count column per category.
    grouped = (
        df.groupby([pd.Grouper(key='date', freq='D'), 'category'])
        .size()
        .unstack(fill_value=0)
        .reset_index()
        .dropna(subset=['date'])
    )
    grouped['date'] = np.array(grouped['date'].dt.date, dtype=np.datetime64)

    # Event-level placeholder columns so daily rows and per-event rows share one schema.
    placeholders = {'event_type': None, 'event_count': None, 'event_name': None,
                    'event_id': None, 'enc_id': -1, 'ICD9': None, 'ICD10': None}
    for col, fill in placeholders.items():
        grouped[col] = [fill] * len(grouped)

    # Per-category bar edges (jittered_start_*/jittered_end_*) around each day.
    get_jittered_times(grouped, 'date')

    # Per-event rows in the same schema, with -1 counts and NaT bar edges.
    new_df = df.rename(columns={'category': 'event_type', 'encounter_id': 'enc_id'})
    for event_type in data_types:
        new_df[event_type] = [-1] * N
        new_df[f"jittered_start_{event_type}"] = pd.Series(dtype='datetime64[s]')
        new_df[f"jittered_end_{event_type}"] = pd.Series(dtype='datetime64[s]')
    new_df['event_count'] = [None] * N

    # Daily rows first, so their labels in new_df equal their positions in `grouped`.
    new_df = pd.concat([grouped, new_df.astype(grouped.dtypes)], axis=0, ignore_index=True)
    source = ColumnDataSource(new_df)  # Used for bar plot

    # Scatter/table data: a date-sorted copy with a random per-event time offset.
    # The caller's df is left untouched.
    ungrouped = df.assign(jittered_time=df['date'].apply(jitter_func)).sort_values(by='date')
    ungrouped_source = ColumnDataSource(ungrouped)  # Used for scatter plot and data table

    max_val, row_id = get_limits(new_df)

    # row_id is the position of the busiest day in `grouped` (see concat above).
    day_dates = grouped['date']
    last = len(day_dates) - 1

    def _date_at(pos):
        """Date of the daily row at ``pos``, clamped to the available rows."""
        return day_dates.iloc[min(max(pos, 0), last)]

    # Create bar graph with events grouped by event type (medications, diagnoses, etc) and plotted by date
    bar_plot = figure(
        x_axis_type='datetime',
        x_axis_location='above',
        width=1500,
        height=300,
        x_range=(_date_at(row_id), _date_at(row_id + 7)),
        y_range=(0, max_val + 50),
        tools="pan,xpan,box_zoom,wheel_zoom,xbox_select,save,reset",
        active_drag='xbox_select'
    )

    for event, c in zip(data_types, colors):
        bar_plot.quad(bottom=0, top=event, left=f'jittered_start_{event}',
                      right=f'jittered_end_{event}', color=c, source=source,
                      fill_alpha=0.4, line_color='black', legend_label=f'{event}')

    bar_plot.legend.location = "top_left"
    bar_plot.legend.orientation = "horizontal"
    bar_plot.legend.click_policy = 'hide'
    bar_plot.yaxis.axis_label = 'Count'

    # Selection strip 1: move the selection area over the full timescale.
    wide_selection = create_selection_tool(
        reference_graph=bar_plot, x_range=None,
        dates=grouped['date'], grouped_df=grouped
    )

    # Selection strip 2: a +/-20 day-row window around the busiest day.
    narrow_selection = create_selection_tool(
        reference_graph=bar_plot,
        x_range=(_date_at(row_id - 20), _date_at(row_id + 20)),
        dates=grouped['date'], grouped_df=grouped
    )

    # Create scatter plot
    scatter_plot = figure(
        x_axis_type='datetime',
        width=1500,
        height=300,
        x_range=bar_plot.x_range,
        y_range=data_types,
        tools="pan,xpan,box_zoom,wheel_zoom,xbox_select,save,reset",
        active_drag='xbox_select'
    )

    hover = HoverTool(tooltips=[('Event ID', '@event_id'),
                                ('Event Name', '@event_name'),
                                ('Encounter ID', '@encounter_id')
                                ])
    scatter_plot.add_tools(hover)

    scatter_plot.scatter(x='jittered_time',
                         y=jitter('category', width=0.6, range=scatter_plot.y_range),
                         source=ungrouped_source, alpha=0.5)

    # Create data table
    ungrouped_cols = [
        TableColumn(field='date', title='Date', formatter=DateFormatter()),
        TableColumn(field='category', title='Event Type'),
        TableColumn(field='encounter_id', title='Encounter ID'),
        TableColumn(field='event_id', title='Event ID'),
        TableColumn(field='event_name', title='Event Name'),
        TableColumn(field='ICD9', title='ICD9'),
        TableColumn(field='ICD10', title='ICD10')
    ]

    ungrouped_data_table = DataTable(source=ungrouped_source, columns=ungrouped_cols, width=2000)

    # Add all plots and the data table to the tab
    scatter_joint_tab = TabPanel(child=column(bar_plot, wide_selection,
                                              narrow_selection, scatter_plot,
                                              ungrouped_data_table),
                                 title='Time Binned')
    tab_list += [scatter_joint_tab]


In [None]:
def plot_tabs(tab_list):
    """Show ``tab_list`` as a single tabbed widget inside the notebook.

    Parameters
    ----------
    tab_list : list of TabPanel
        Panels in display order.
    """
    tabbed_view = Tabs(tabs=tab_list)
    output_notebook()  # direct show() output into the notebook
    show(tabbed_view)

In [None]:
if __name__ == '__main__':
    # Each builder appends one TabPanel; list order becomes tab order.
    tab_list = []
    df, encounter_df = process_data()
    # Time-binned view first, then the per-encounter view.
    plot_time_tab(df, tab_list)
    plot_encounter_tab(df, encounter_df, tab_list)
    # Render all tabs inline in the notebook.
    plot_tabs(tab_list)