In [6]:
import pandas as pd
import numpy as np
import ast
import seaborn as sns
#import networkx as nx
import plotly.graph_objects as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
#from pyvis.network import Network
#import holoviews as hv
import random
import matplotlib.pyplot as plt
%matplotlib inline

In [7]:
#hv.extension('bokeh')
init_notebook_mode(connected=True)

#### The test dataframe - Example of Iron deficiency anemia

In [10]:
#test_df = pd.read_csv('test_dfs/test_df_with_hb_some_nans_2e6.csv')
#test_df =pd.read_csv('test_dfs/noisy/final_test_df_4e6.csv')
test_df = pd.read_csv('test_dfs/unspecified/test_df_2e6.csv')
test_df.head()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success
0,0.0,4.0,1.0,5.0,5.0,"['hemoglobin', 'rbc', 'ret_count', 'Hemolytic ...",0.0,1.0
1,1.0,5.0,1.0,1.0,1.0,"['hemoglobin', 'rbc', 'segmented_neutrophils',...",0.0,1.0
2,2.0,5.0,1.0,4.0,4.0,"['hemoglobin', 'rbc', 'ferritin', 'tibc', 'Iro...",0.0,1.0
3,3.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0
4,4.0,4.0,-1.0,7.0,7.0,"['hemoglobin', 'rbc', 'segmented_neutrophils',...",1.0,1.0


In [11]:
# import_dict = {'Hemolytic anemia': 0, 'Anemia of chronic disease': 1, 'Aplastic anemia': 2, 'Iron deficiency anemia': 3,
#                'Vitamin B12/Folate deficiency anemia': 4, 'Thalassemia': 5, 'ferritin':6, 'ret_count':7, 
#                'segmented_neutrophils':8, 'iron':9, 'tibc':10, 'rbc':11, 'mcv':12, 'mentzer_index':13}

#### Util functions

In [12]:
# anemias = ['Hemolytic anemia', 'Anemia of chronic disease', 'Aplastic anemia', 'Iron deficiency anemia', 
#                'Vitamin B12/Folate deficiency anemia', 'Thalassemia']
# anemias = ['No anemia', 'Hemolytic anemia', 'Aplastic anemia', 'Anemia of chronic disease', 
#            'Vitamin B12/Folate deficiency anemia', 'Iron deficiency anemia']
# anemias = ['Aplastic anemia', 'No anemia', 'Hemolytic anemia', 'Anemia of chronic disease', 'Iron deficiency anemia',
#            'Vitamin B12/Folate deficiency anemia']
#i normally use
# anemias = ['No anemia', 'Hemolytic anemia', 'Aplastic anemia', 'Iron deficiency anemia', 'Vitamin B12/Folate deficiency anemia', 
#            'Anemia of chronic disease']
#for initial noisy dataset
# anemias = ['Aplastic anemia', 'Hemolytic anemia', 'No anemia', 'Anemia of chronic disease', 
#            'Vitamin B12/Folate deficiency anemia', 'Iron deficiency anemia']

#final_test_df_4e6
# anemias = ['Hemolytic anemia', 'Anemia of chronic disease', 'No anemia', 'Aplastic anemia', 
#            'Vitamin B12/Folate deficiency anemia','Iron deficiency anemia']
anemias = ['No anemia', 'Hemolytic anemia', 'Aplastic anemia', 'Iron deficiency anemia', 'Vitamin B12/Folate deficiency anemia',
           'Anemia of chronic disease', 'Unspecified anemia']

In [13]:
get_colors = lambda n: list(map(lambda i: "#" + "%06x" % random.randint(0, 0xFFFFFF),range(n)))

In [14]:
def generate_filename(i):
    anemia = anemias[i]
    filename = anemia.lower().replace(' ', '_').replace('/','_')
    return filename

In [15]:
def generate_title(i, patient_num):
    anemia = anemias[i]
    title = f'Diagnosis Pathway for {anemia} - ({patient_num} patients)'
    return title

In [16]:
def generate_tuple_dict(df):
    frequency_dict = {}
    for traj in df.trajectory:
        if traj in frequency_dict.keys():
            frequency_dict[traj] += 1
        else:
            frequency_dict[traj] = 1
    print(f'frequency_dict: {frequency_dict}')
    overall_tup_dict = {}
    for key, value in frequency_dict.items():
        new_key = ast.literal_eval(key)
        for tup in zip(new_key, new_key[1:]):
            #print(f'tup: {tup}')
            if tup in overall_tup_dict.keys():
                overall_tup_dict[tup] += value
            else:
                overall_tup_dict[tup] = value
    print(f'overall_tup_dict: {overall_tup_dict}')
    return overall_tup_dict

In [12]:
def draw_sankey_diagram(df, title, save=False, filename=False):
    overall_tuple_dict = generate_tuple_dict(df)
    sankey_df = pd.DataFrame()
    sankey_df['Label1'] = [i[0] for i in overall_tuple_dict.keys()]
    sankey_df['Label2'] = [i[1] for i in overall_tuple_dict.keys()]
    sankey_df['value'] = list(overall_tuple_dict.values())
    unique_actions = list(set(list(sankey_df['Label1'].unique())  + list(sankey_df['Label2'].unique())))
    dmap = dict(zip(unique_actions, range(len(unique_actions))))
    sankey_df['source'] = sankey_df['Label1'].map(dmap)
    sankey_df['target'] = sankey_df['Label2'].map(dmap)
    print(sankey_df)
    sankey_df.sort_values(by=['source'], inplace=True)
    nodes_color = get_colors(len(dmap))
    label = unique_actions
    target = list(sankey_df['target'])
    value = list(sankey_df['value'])
    source = list(sankey_df['source'])
    link_color = get_colors(len(value))
    fig = go.Figure(data=[go.Sankey(
        node = dict(pad=15, thickness=20, line=dict(color='black', width=0.5), label=label, color=nodes_color),
        link = dict(source=source, target=target, value=value, color=link_color)
    )])
    fig.update_layout(title_text=title, title_x=0.5,  title_font_size=24, title_font_color='black', 
                      title_font_family='Times New Roman')
    if save:
        fig.write_html(f'pathways/unspecified/final_sankey_{filename}.html')
    fig.show()

#### Entire test_df

In [13]:
draw_sankey_diagram(test_df, 'Overall pathways', save=True, filename='test_df')

frequency_dict: {"['hemoglobin', 'mcv', 'ret_count', 'Aplastic anemia']": 2678, "['hemoglobin', 'mcv', 'ret_count', 'Hemolytic anemia']": 4248, "['hemoglobin', 'mcv', 'Unspecified anemia']": 620, "['hemoglobin', 'No anemia']": 2790, "['hemoglobin', 'mcv', 'tibc', 'ferritin', 'Aplastic anemia']": 505, "['hemoglobin', 'mcv', 'No anemia']": 200, "['hemoglobin', 'mcv', 'ret_count', 'ret_count', 'ret_count', 'ret_count', 'ret_count', 'ret_count']": 147, "['hemoglobin', 'mcv', 'tibc', 'ferritin', 'Iron deficiency anemia']": 377, "['hemoglobin', 'mcv', 'ret_count', 'segmented_neutrophils', 'segmented_neutrophils', 'segmented_neutrophils', 'segmented_neutrophils', 'segmented_neutrophils']": 33, "['hemoglobin', 'mcv', 'Vitamin B12/Folate deficiency anemia']": 142, "['hemoglobin', 'mcv', 'ret_count', 'ferritin', 'ferritin', 'ferritin', 'ferritin', 'ferritin']": 19, "['hemoglobin', 'mcv', 'mcv', 'mcv', 'mcv', 'mcv', 'mcv', 'mcv']": 121, "['hemoglobin', 'mcv', 'Anemia of chronic disease']": 4, "['

#### Terminated episodes

In [14]:
terminated_df = test_df[test_df.y_pred.isna()]

In [15]:
draw_sankey_diagram(terminated_df, f'pathways of terminated episodes - {len(terminated_df)}', save=True, 
                    filename='terminated_df')

frequency_dict: {"['hemoglobin', 'mcv', 'ret_count', 'ret_count', 'ret_count', 'ret_count', 'ret_count', 'ret_count']": 147, "['hemoglobin', 'mcv', 'ret_count', 'segmented_neutrophils', 'segmented_neutrophils', 'segmented_neutrophils', 'segmented_neutrophils', 'segmented_neutrophils']": 33, "['hemoglobin', 'mcv', 'ret_count', 'ferritin', 'ferritin', 'ferritin', 'ferritin', 'ferritin']": 19, "['hemoglobin', 'mcv', 'mcv', 'mcv', 'mcv', 'mcv', 'mcv', 'mcv']": 121, "['hemoglobin', 'mcv', 'ret_count', 'mcv', 'mcv', 'mcv', 'mcv', 'mcv']": 28, "['hemoglobin', 'mcv', 'ret_count', 'tibc', 'ferritin', 'ferritin', 'ferritin', 'ferritin']": 4, "['hemoglobin', 'mcv', 'ret_count', 'tibc', 'tibc', 'tibc', 'tibc', 'tibc']": 2, "['hemoglobin', 'mcv', 'ferritin', 'ferritin', 'ferritin', 'ferritin', 'ferritin', 'ferritin']": 29, "['hemoglobin', 'mcv', 'tibc', 'ferritin', 'segmented_neutrophils', 'segmented_neutrophils', 'segmented_neutrophils', 'segmented_neutrophils']": 2, "['hemoglobin', 'mcv', 'ret_co

#### Success df

In [18]:
success_df = pd.read_csv('test_dfs/unspecified/success_df_2e6.csv')

In [19]:
draw_sankey_diagram(success_df, 'Pathways of successful episodes', save=True, filename='success_df')

frequency_dict: {"['hemoglobin', 'mcv', 'ret_count', 'Hemolytic anemia']": 4189, "['hemoglobin', 'mcv', 'Unspecified anemia']": 310, "['hemoglobin', 'No anemia']": 2790, "['hemoglobin', 'mcv', 'ret_count', 'Aplastic anemia']": 2662, "['hemoglobin', 'mcv', 'tibc', 'ferritin', 'Iron deficiency anemia']": 377, "['hemoglobin', 'mcv', 'Vitamin B12/Folate deficiency anemia']": 64, "['hemoglobin', 'mcv', 'No anemia']": 93, "['hemoglobin', 'mcv', 'ret_count', 'No anemia']": 5, "['hemoglobin', 'mcv', 'ret_count', 'ferritin', 'Aplastic anemia']": 6}
overall_tup_dict: {('hemoglobin', 'mcv'): 7706, ('mcv', 'ret_count'): 6862, ('ret_count', 'Hemolytic anemia'): 4189, ('mcv', 'Unspecified anemia'): 310, ('hemoglobin', 'No anemia'): 2790, ('ret_count', 'Aplastic anemia'): 2662, ('mcv', 'tibc'): 377, ('tibc', 'ferritin'): 377, ('ferritin', 'Iron deficiency anemia'): 377, ('mcv', 'Vitamin B12/Folate deficiency anemia'): 64, ('mcv', 'No anemia'): 93, ('ret_count', 'No anemia'): 5, ('ret_count', 'ferriti

#### y_pred_df

In [20]:
y_pred_df = pd.read_csv('test_dfs/unspecified/y_pred_df_2e6.csv')

In [21]:
draw_sankey_diagram(y_pred_df, 'Pathways of completed episodes', save=True, filename='y_pred_df')

frequency_dict: {"['hemoglobin', 'mcv', 'ret_count', 'Aplastic anemia']": 2678, "['hemoglobin', 'mcv', 'ret_count', 'Hemolytic anemia']": 4248, "['hemoglobin', 'mcv', 'Unspecified anemia']": 620, "['hemoglobin', 'No anemia']": 2790, "['hemoglobin', 'mcv', 'tibc', 'ferritin', 'Aplastic anemia']": 505, "['hemoglobin', 'mcv', 'No anemia']": 200, "['hemoglobin', 'mcv', 'tibc', 'ferritin', 'Iron deficiency anemia']": 377, "['hemoglobin', 'mcv', 'Vitamin B12/Folate deficiency anemia']": 142, "['hemoglobin', 'mcv', 'Anemia of chronic disease']": 4, "['hemoglobin', 'mcv', 'ferritin', 'Aplastic anemia']": 14, "['hemoglobin', 'mcv', 'ret_count', 'ferritin', 'Aplastic anemia']": 16, "['hemoglobin', 'mcv', 'ret_count', 'Anemia of chronic disease']": 3, "['hemoglobin', 'mcv', 'ret_count', 'No anemia']": 8, "['hemoglobin', 'mcv', 'tibc', 'No anemia']": 2, "['hemoglobin', 'mcv', 'ret_count', 'Unspecified anemia']": 1, "['hemoglobin', 'mcv', 'tibc', 'Iron deficiency anemia']": 2}
overall_tup_dict: {('

#### The Anemias

In [22]:
for i in range(6):
    print(anemias[i])
    anemia_df = test_df[test_df.y_pred==i]
    if len(anemia_df!=0):
        draw_sankey_diagram(anemia_df, generate_title(i, len(anemia_df)), save=True, filename=generate_filename(i))

No anemia
frequency_dict: {"['hemoglobin', 'No anemia']": 2790, "['hemoglobin', 'mcv', 'No anemia']": 200, "['hemoglobin', 'mcv', 'ret_count', 'No anemia']": 8, "['hemoglobin', 'mcv', 'tibc', 'No anemia']": 2}
overall_tup_dict: {('hemoglobin', 'No anemia'): 2790, ('hemoglobin', 'mcv'): 210, ('mcv', 'No anemia'): 200, ('mcv', 'ret_count'): 8, ('ret_count', 'No anemia'): 8, ('mcv', 'tibc'): 2, ('tibc', 'No anemia'): 2}
       Label1     Label2  value  source  target
0  hemoglobin  No anemia   2790       3       2
1  hemoglobin        mcv    210       3       0
2         mcv  No anemia    200       0       2
3         mcv  ret_count      8       0       1
4   ret_count  No anemia      8       1       2
5         mcv       tibc      2       0       4
6        tibc  No anemia      2       4       2


Hemolytic anemia
frequency_dict: {"['hemoglobin', 'mcv', 'ret_count', 'Hemolytic anemia']": 4248}
overall_tup_dict: {('hemoglobin', 'mcv'): 4248, ('mcv', 'ret_count'): 4248, ('ret_count', 'Hemolytic anemia'): 4248}
       Label1            Label2  value  source  target
0  hemoglobin               mcv   4248       1       0
1         mcv         ret_count   4248       0       2
2   ret_count  Hemolytic anemia   4248       2       3


Aplastic anemia
frequency_dict: {"['hemoglobin', 'mcv', 'ret_count', 'Aplastic anemia']": 2678, "['hemoglobin', 'mcv', 'tibc', 'ferritin', 'Aplastic anemia']": 505, "['hemoglobin', 'mcv', 'ferritin', 'Aplastic anemia']": 14, "['hemoglobin', 'mcv', 'ret_count', 'ferritin', 'Aplastic anemia']": 16}
overall_tup_dict: {('hemoglobin', 'mcv'): 3213, ('mcv', 'ret_count'): 2694, ('ret_count', 'Aplastic anemia'): 2678, ('mcv', 'tibc'): 505, ('tibc', 'ferritin'): 505, ('ferritin', 'Aplastic anemia'): 535, ('mcv', 'ferritin'): 14, ('ret_count', 'ferritin'): 16}
       Label1           Label2  value  source  target
0  hemoglobin              mcv   3213       2       0
1         mcv        ret_count   2694       0       1
2   ret_count  Aplastic anemia   2678       1       3
3         mcv             tibc    505       0       5
4        tibc         ferritin    505       5       4
5    ferritin  Aplastic anemia    535       4       3
6         mcv         ferritin     14       0       4
7   ret_cou

Iron deficiency anemia
frequency_dict: {"['hemoglobin', 'mcv', 'tibc', 'ferritin', 'Iron deficiency anemia']": 377, "['hemoglobin', 'mcv', 'tibc', 'Iron deficiency anemia']": 2}
overall_tup_dict: {('hemoglobin', 'mcv'): 379, ('mcv', 'tibc'): 379, ('tibc', 'ferritin'): 377, ('ferritin', 'Iron deficiency anemia'): 377, ('tibc', 'Iron deficiency anemia'): 2}
       Label1                  Label2  value  source  target
0  hemoglobin                     mcv    379       1       0
1         mcv                    tibc    379       0       4
2        tibc                ferritin    377       4       2
3    ferritin  Iron deficiency anemia    377       2       3
4        tibc  Iron deficiency anemia      2       4       3


Vitamin B12/Folate deficiency anemia
frequency_dict: {"['hemoglobin', 'mcv', 'Vitamin B12/Folate deficiency anemia']": 142}
overall_tup_dict: {('hemoglobin', 'mcv'): 142, ('mcv', 'Vitamin B12/Folate deficiency anemia'): 142}
       Label1                                Label2  value  source  target
0  hemoglobin                                   mcv    142       1       0
1         mcv  Vitamin B12/Folate deficiency anemia    142       0       2


Anemia of chronic disease
frequency_dict: {"['hemoglobin', 'mcv', 'Anemia of chronic disease']": 4, "['hemoglobin', 'mcv', 'ret_count', 'Anemia of chronic disease']": 3}
overall_tup_dict: {('hemoglobin', 'mcv'): 7, ('mcv', 'Anemia of chronic disease'): 4, ('mcv', 'ret_count'): 3, ('ret_count', 'Anemia of chronic disease'): 3}
       Label1                     Label2  value  source  target
0  hemoglobin                        mcv      7       1       0
1         mcv  Anemia of chronic disease      4       0       3
2         mcv                  ret_count      3       0       2
3   ret_count  Anemia of chronic disease      3       2       3


#### Pathway Analysis 

#### Hemolytic anemia

#### Iron Deficiency Anemia(3)

In [None]:
ida_df = test_df[test_df.y_pred==3]
ida_orig_df = df[df.label==3]
ida_df.head()

In [None]:
#rows with tibc in their trajectory
tibc_ida_df = pd.DataFrame()
count=0
for i, row in ida_df.iterrows():
    try:
        traj = ast.literal_eval(row.trajectory)
    except:
        traj = row.trajectory
    if 'tibc' in traj:
        tibc_ida_df = tibc_ida_df.append(row)
tibc_ida_df

In [None]:
array_of_interest = ida_orig_df['ferritin'].values
get_min_of_array(array_of_interest, 50)

#### Aplastic Anemia (2)

In [None]:
aplastic_df = test_df[test_df.y_pred==2]
aplastic_orig_df = df[df.label==2]
aplastic_df.head()

In [None]:
#rows with hypersegmented neutrophils in their trajectory
neutrophils_aplastic_df = pd.DataFrame()
count=0
for i, row in aplastic_df.iterrows():
    try:
        traj = ast.literal_eval(row.trajectory)
    except:
        traj = row.trajectory
    if 'segmented_neutrophils' in traj:
        neutrophils_aplastic_df = neutrophils_aplastic_df.append(row)
neutrophils_aplastic_df

In [None]:
X_test[5052]

#### Vitamin B12 (4)

In [None]:
#misdiagnosed B12
mis_b12 = test_df[((test_df.y_actual==4) & (test_df.y_pred !=4)) & (test_df.y_pred.notna())]
mis_b12

In [None]:
X_test[6627]

In [None]:
X_test_df[X_test_df[2] !=0][2].values.min()

In [None]:
mis_b12.iloc[0]['trajectory']

In [None]:
X_test[6627]

In [None]:
A = X_test_df[X_test_df[2] !=0].values[:, 2]
k= 10
idx = np.argpartition(A, k)
A[idx[:k]]

In [None]:
df.head()

#### delete from here

In [None]:
draw_sankey_diagram(ida_df)

#### Using Plotly

#### Generalize and generate for all slides

#### Using holoviews

In [None]:
def draw_sankey_diagram_hv(source, target, value, filename=False, save=False):
    hv_df = pd.DataFrame()
    hv_df['source'] = source
    hv_df['target'] = target
    hv_df['value'] = value
    hv_sankey = hv.Sankey(hv_df)
    hv_sankey.opts(cmap='Colorblind', label_position='left', edge_color='target', edge_line_width=0, node_alpha=1.0, node_width=40,
               node_sort=True, width=800, height=400, bgcolor='snow', title='Clinical pathways')
    if save:
        hv.save(hv_sankey, f'pathways/{filename}.png', fmt='png')