In [1]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import random
import sys
sys.path.append('../..')
from modules.many_features import utils, constants
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def create_sankey(df, title, save, filename):
    """Split episodes by prediction correctness and draw them as one Sankey diagram.

    Rows whose ``y_actual`` equals ``y_pred`` and rows where the two differ are
    handed, in that order, to ``utils.draw_sankey_diagram`` along with
    ``title``, ``save`` and ``filename`` (all forwarded unchanged).
    Original author's note: shows true and false positives.
    """
    correct_mask = df.y_actual == df.y_pred
    wrong_mask = df.y_actual != df.y_pred
    utils.draw_sankey_diagram(df[correct_mask], df[wrong_mask], title, save, filename)

In [3]:
def create_sankey_df(df):
    """Build the link table (one row per transition) that feeds a Sankey diagram.

    ``utils.generate_tuple_dict(df)`` is consumed as a mapping whose keys are
    ``(from_label, to_label)`` pairs and whose values are the link weights.

    Parameters
    ----------
    df : passed straight to ``utils.generate_tuple_dict``.

    Returns
    -------
    pandas.DataFrame
        Columns ``Label1`` (link start), ``Label2`` (link end) and ``value``
        (link weight). The two long diagnosis names are shortened to
        ``'ACD'`` and ``'IDA'`` in both label columns. An input that yields no
        transitions produces an empty frame with the same three columns.
    """
    overall_tup_dict = utils.generate_tuple_dict(df)
    sankey_df = pd.DataFrame({
        'Label1': [pair[0] for pair in overall_tup_dict],
        'Label2': [pair[1] for pair in overall_tup_dict],
        'value': list(overall_tup_dict.values()),
    })
    # An earlier revision looked up the first diagnosis-like label in Label2 here,
    # only to feed a terminator row that was already commented out. The lookup was
    # unused and raised IndexError whenever no such label existed (e.g. an empty
    # frame of mispredictions), so it has been removed.
    short_names = {'Anemia of chronic disease': 'ACD', 'Iron deficiency anemia': 'IDA'}
    sankey_df['Label1'] = sankey_df['Label1'].replace(short_names)
    sankey_df['Label2'] = sankey_df['Label2'].replace(short_names)
    return sankey_df

In [4]:
def draw_sankey_diagram(pos_df, neg_df, title, save=False, filename=False):
    """Draw one Sankey diagram combining two sets of pathways.

    Both inputs are converted to link tables with ``create_sankey_df``. Links
    coming from ``pos_df`` are drawn green and links from ``neg_df`` red.

    Parameters
    ----------
    pos_df, neg_df : inputs accepted by ``create_sankey_df``.
    title : str
        Figure title.
    save : bool, default False
        When truthy, the figure is also written to ``f'{filename}.html'``.
    filename : str, default False
        Output path without the ``.html`` extension; required when ``save`` is truthy.

    Raises
    ------
    ValueError
        If ``save`` is truthy but no filename was given (previously this
        silently wrote a file literally named ``False.html``).
    """
    if save and not filename:
        raise ValueError('filename is required when save=True')

    pos_sankey_df = create_sankey_df(pos_df)
    neg_sankey_df = create_sankey_df(neg_df)

    # Collect node labels in first-seen order. A plain set() made the node
    # indices depend on string hashing, i.e. differ from one kernel to the next.
    unique_actions = list(dict.fromkeys(
        node_label
        for frame in (pos_sankey_df, neg_sankey_df)
        for column in ('Label1', 'Label2')
        for node_label in frame[column]
    ))
    dmap = {action: idx for idx, action in enumerate(unique_actions)}

    # The helper lives in `utils`; the bare name used before is not defined in
    # this notebook and raised NameError.
    pos_sankey_df = utils.create_source_and_target(pos_sankey_df, dmap)
    neg_sankey_df = utils.create_source_and_target(neg_sankey_df, dmap)

    # NOTE(review): only labels containing lowercase 'anemia' or 'ACD' are green.
    # Labels shortened to 'IDA' by create_sankey_df, or starting with a capital
    # 'Anemia', fall through to orange - confirm whether that is intended.
    nodes_color = ['green' if ('anemia' in node or 'ACD' in node) else 'orange'
                   for node in unique_actions]

    label = unique_actions

    # Correct pathways first, then incorrect ones; link_color relies on this order.
    source = list(pos_sankey_df['source']) + list(neg_sankey_df['source'])
    target = list(pos_sankey_df['target']) + list(neg_sankey_df['target'])
    value = list(pos_sankey_df['value']) + list(neg_sankey_df['value'])
    link_color = ['green']*len(pos_sankey_df) + ['red']*len(neg_sankey_df)

    fig = go.Figure(data=[go.Sankey(
        node = dict(pad=15, thickness=20, line=dict(color='black', width=0.5), label=label, color=nodes_color),
        link= dict(source=source, target=target, value=value, color=link_color)
    )])
    fig.update_layout(title_text=title,
                      title_x=0.5,
                      title_font_size=24,
                      title_font_color='black',
                      title_font_family='Times New Roman',
                      font = dict(family='Times New Roman', size=38),
                      paper_bgcolor='rgba(0, 0, 0, 0)',
                      plot_bgcolor='rgba(0, 0, 0, 0)'
                      )

    if save:
        fig.write_html(f'{filename}.html')
    fig.show()


#### Test df

In [5]:
# Load the saved per-episode test results and preview the first rows.
# Later cells rely on the `y_pred`, `y_actual` and `trajectory` columns.
test_df = pd.read_csv('../../../anemia_ml4hc/test_dfs/dueling_dqn_pr_test_df_seed_63_10000000.csv')
test_df.head()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success
0,0.0,5.0,1.0,5.0,5.0,"['hemoglobin', 'gender', 'rbc', 'ret_count', '...",0.0,1.0
1,1.0,6.0,1.0,1.0,1.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",0.0,1.0
2,2.0,8.0,1.0,4.0,4.0,"['hemoglobin', 'rbc', 'mcv', 'ferritin', 'hema...",0.0,1.0
3,3.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0
4,4.0,6.0,-1.0,7.0,7.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",1.0,1.0


In [6]:
# Score the predictions against the ground truth with the project helper.
# The recorded output is a 3-tuple of floats - presumably accuracy, F1 and
# ROC-AUC in percent; TODO confirm against utils.test.
utils.test(test_df.y_actual, test_df.y_pred)

(97.18571428571428, 97.1643780144586, 98.38007040050316)

#### delete from here

In [7]:
# Keep only the episodes whose predicted class is 0.
sample_df = test_df.loc[test_df['y_pred'] == 0]

In [8]:
# Split the slice into correctly and incorrectly predicted episodes.
pos_df = sample_df.loc[sample_df['y_actual'] == sample_df['y_pred']]
neg_df = sample_df.loc[sample_df['y_actual'] != sample_df['y_pred']]
pos_df.head()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success
3,3.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0
9,9.0,3.0,1.0,0.0,0.0,"['hemoglobin', 'gender', 'No anemia']",0.0,1.0
17,17.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0
21,21.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0
32,32.0,3.0,1.0,0.0,0.0,"['hemoglobin', 'gender', 'No anemia']",0.0,1.0


In [9]:
# Turn each subset into a link table (Label1 -> Label2 with a weight).
pos_sankey_df, neg_sankey_df = create_sankey_df(pos_df), create_sankey_df(neg_df)
pos_sankey_df

Unnamed: 0,Label1,Label2,value
0,hemoglobin,No anemia,337
1,hemoglobin,gender,1052
2,gender,No anemia,1663
3,hemoglobin,hematocrit,611
4,hematocrit,gender,611


In [10]:
# Every distinct node label that appears at either end of any link.
unique_actions = list(set([
    *pos_sankey_df['Label1'].unique(),
    *pos_sankey_df['Label2'].unique(),
    *neg_sankey_df['Label1'].unique(),
    *neg_sankey_df['Label2'].unique(),
]))
unique_actions

['hemoglobin', 'hematocrit', 'gender', 'No anemia']

In [11]:
# Map each node label to its position in `unique_actions`.
dmap = {action: position for position, action in enumerate(unique_actions)}
dmap

{'hemoglobin': 0, 'hematocrit': 1, 'gender': 2, 'No anemia': 3}

In [12]:
# Attach integer node ids to every link: the helper receives the label -> index
# map and, per the preview below, returns the frame with `source` and `target` columns.
# NOTE(review): both names are overwritten with the helper's result, so re-running
# this cell feeds the already-augmented frames back in - confirm that is harmless.
pos_sankey_df = utils.create_source_and_target(pos_sankey_df, dmap)
neg_sankey_df = utils.create_source_and_target(neg_sankey_df, dmap)
pos_sankey_df.head()

Unnamed: 0,Label1,Label2,value,source,target
0,hemoglobin,No anemia,337,0,3
1,hemoglobin,gender,1052,0,2
3,hemoglobin,hematocrit,611,0,1
4,hematocrit,gender,611,1,2
2,gender,No anemia,1663,2,3


In [13]:
# Link colours: correct pathways vs. incorrect pathways.
pos_color, neg_color = 'LimeGreen', 'Red'

In [14]:
# pos_mask = pos_sankey_df['value'] >= 0.5
# neg_mask = neg_sankey_df['value'] >= 0.5
# pos_sankey_df['color'] = np.where(pos_mask, pos_color, 'white')
# neg_sankey_df['color'] = np.where(neg_mask, neg_color, 'white')

In [15]:
# Nodes whose label contains 'anemia' or 'ACD' are purple; all other nodes are orange.
nodes_color = [
    'darkorchid' if ('anemia' in node or 'ACD' in node) else 'orange'
    for node in unique_actions
]
nodes_color

['orange', 'orange', 'orange', 'darkorchid']

In [16]:
# Node labels for the diagram. This is the same list object as `unique_actions`
# (an alias, not a copy).
label = unique_actions
label

['hemoglobin', 'hematocrit', 'gender', 'No anemia']

In [18]:
# Assemble the link arrays: correct pathways first, then incorrect ones, so that
# the colour list below lines up with them position by position.
title = 'No anemia pathways'
source = [*pos_sankey_df['source'], *neg_sankey_df['source']]
target = [*pos_sankey_df['target'], *neg_sankey_df['target']]
value = [*pos_sankey_df['value'], *neg_sankey_df['value']]
link_color1 = len(pos_sankey_df) * [pos_color] + len(neg_sankey_df) * [neg_color]
# link_color2 = pos_sankey_df['color'].tolist() + neg_sankey_df['color'].tolist()
link_color1

['LimeGreen', 'LimeGreen', 'LimeGreen', 'LimeGreen', 'LimeGreen', 'Red', 'Red']

In [20]:
# Display the node labels again before building the figure.
label

['hemoglobin', 'hematocrit', 'gender', 'No anemia']

In [22]:
# Build the Sankey figure for this slice with custom hover text on the nodes.
# (Disabled draft of per-link hover text follows.)
# custom_hover_text = [
#     f"Source: {hoverinformation[0]}<br>Target: {hoverinformation[1]}<br>Value: {value}"
#     for source, target, value in zip(source, target, value)
# ]

# One hover string per node, aligned with `label`. The strings are static
# placeholders ('acc:acc', 'Node: node name', ...) - no metric or node value is
# substituted into them yet.
node_hover_templates = ['Terminal node<br>acc:acc<br>f1:f1<br>ROC-AUC:roc_auc' if (('anemia' in node)|('ACD' in node)) else 'Node: node name<br>Value: value' for node in label]

fig = go.Figure(data=[go.Sankey(
    node = dict(pad=15, 
                thickness=20, 
                line=dict(color='black', width=0.5), 
                label=label, 
                color=nodes_color, 
                # Show each node's own entry from `customdata` as its hover text.
                hovertemplate='%{customdata}', 
                customdata=node_hover_templates),
    link= dict(source=source, 
               target=target, 
               value=value, 
               color=link_color1, 
#                hovertemplate='%{customdata}',
#                customdata=custom_hover_text,
              )
)])
# Title/font styling on a transparent background. This call is the cell's last
# expression; presumably its return value is what renders the figure - TODO confirm.
fig.update_layout(title_text=title, 
                  title_x=0.5,  
                  title_font_size=24, 
                  title_font_color='black', 
                  title_font_family='Times New Roman', 
                  font = dict(family='Times New Roman', size=18, color='black'),
                  paper_bgcolor='rgba(0, 0, 0, 0)',
                  plot_bgcolor='rgba(0, 0, 0, 0)'
                  )

#### Getting the data distribution of the source

#### end here

In [None]:
# utils.plot_classification_report(test_df['y_actual'], test_df['y_pred'])

In [None]:
# utils.plot_confusion_matrix(test_df['y_actual'], test_df['y_pred'])

In [None]:
# create_sankey(test_df, 'Overall pathways', True, 
#               filename = '../../../anemia_ml4hc/pathways/seed_0/test_df_dueling_dqn_pr_seed_147_basic_10000000')

#### Success df

In [None]:
# success_df = pd.read_csv('../../final/test_dfs/dqn_success_mcv_rbc_50_9000000.csv')
# create_sankey(success_df, 'Pathways of successful episodes', save=False, 
#               filename='../../final/pathways/success_df_mcv_rbc_50_9000000')

In [None]:
# utils.draw_sankey_diagram(success_df, 'Pathways of successful episodes', save=True, 
#                          filename='../../pathways/many_features/0.1/correlated/tsuccess_df3_noisy6_230000000')

In [None]:
# for i in range(constants.CLASS_NUM):
#     print(utils.anemias[i])
#     anemia_df = test_df[test_df.y_pred==i]
#     if len(anemia_df!=0):
#         utils.draw_sankey_diagram(anemia_df, utils.generate_title(i, len(anemia_df)), save=True, 
#                                   filename=f'../../pathways/many_features/0.1/correlated/{utils.generate_filename(i)}_noisy6_23000000')

In [None]:
# Precision view: for every predicted class, draw its true positives and false positives.
# NOTE(review): test_df is read from a file named '...seed_63...' while the outputs
# here are written under 'seed_147' - confirm the intended seed.
for i in range(constants.CLASS_NUM):
    print(utils.anemias[i])
    anemia_df = test_df[test_df.y_pred==i]
    # Skip classes that were never predicted. The check used to read
    # `len(anemia_df != 0)`, which compared every cell with 0 and only worked
    # because the comparison result has the same number of rows.
    if len(anemia_df) != 0:
        create_sankey(
            anemia_df,
            utils.generate_title(i, len(anemia_df)),
            save=True,
            filename=f'../../../anemia_ml4hc/pathways/seed_147/{utils.generate_filename(i)}_dueling_dqn_pr_basic_seed_147_10000000',
        )

In [None]:
# Recall view: for every actual class, draw its true positives and false negatives.
# NOTE(review): test_df is read from a file named '...seed_63...' while the outputs
# here are written under 'seed_147' - confirm the intended seed.
for i in range(constants.CLASS_NUM):
    anemia_df = test_df[test_df.y_actual == i]
    # Skip classes with no episodes. The check used to read `len(anemia_df != 0)`,
    # which compared every cell with 0 and only worked because the comparison
    # result has the same number of rows.
    if len(anemia_df) != 0:
        create_sankey(
            anemia_df,
            utils.generate_title(i, len(anemia_df)),
            save=True,
            filename=f'../../../anemia_ml4hc/pathways/seed_147/recall/{utils.generate_filename(i)}_dueling_dqn_pr_basic_seed_147_10000000',
        )
        

#### For mlhc - to delete

In [None]:
# Episodes whose actual class is 5 but which were predicted as class 1 or 3
# (presumably 5 is the hemolytic class, going by the title - TODO confirm).
misdiag_df = test_df.loc[test_df['y_actual'].eq(5) & test_df['y_pred'].isin([1, 3])]
create_sankey(misdiag_df, 'misdiagnosed hemolytic episodes', False, '')