In [1]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import random
import sys
sys.path.append('../..')
from modules.many_features import utils, constants
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def create_sankey(df, title, save, filename):
    """Split ``df`` by prediction correctness and draw the Sankey diagram.

    Rows where ``y_actual`` equals ``y_pred`` form the first ("positive")
    frame and rows where they differ form the second ("negative") frame; both
    are forwarded, together with ``title``, ``save`` and ``filename``, to
    ``utils.draw_sankey_diagram``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must expose ``y_actual`` and ``y_pred`` columns.
    title : str
        Title forwarded to the drawing helper.
    save : bool
        Forwarded unchanged to the drawing helper.
    filename : str
        Forwarded unchanged to the drawing helper.
    """
    # NOTE(review): this calls the implementation in the project's `utils`
    # module; a function named draw_sankey_diagram is also defined in this
    # notebook and is NOT the one used here -- confirm that is intended.
    is_correct = df.y_actual == df.y_pred
    is_wrong = df.y_actual != df.y_pred
    utils.draw_sankey_diagram(df[is_correct], df[is_wrong], title, save, filename)

In [3]:
def draw_sankey_diagram(pos_df, neg_df, title, save=False, filename=False):
    """Draw one Sankey diagram combining the links of ``pos_df`` and ``neg_df``.

    Both frames are converted to link tables with ``create_sankey_df``. Links
    coming from ``pos_df`` are coloured with the notebook-level ``pos_color``
    and links from ``neg_df`` with ``neg_color``; both globals, as well as
    ``create_sankey_df``, must be defined before this function is called.

    Parameters
    ----------
    pos_df, neg_df : pandas.DataFrame
        Episode tables accepted by ``create_sankey_df``.
    title : str
        Figure title.
    save : bool, default False
        When true, the figure is also written to ``f'{filename}.html'``.
    filename : str or False, default False
        Output path without the ``.html`` extension. Required when ``save``
        is true.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(pos_sankey_df, neg_sankey_df)`` as returned by
        ``utils.create_source_and_target``.

    Raises
    ------
    ValueError
        If ``save`` is true but no ``filename`` was supplied (previously this
        silently wrote a file literally named ``False.html``).
    """
    if save and not filename:
        raise ValueError("filename is required when save=True")

    pos_sankey_df = create_sankey_df(pos_df)
    neg_sankey_df = create_sankey_df(neg_df)

    # De-duplicate node labels while keeping first-seen order. A plain set()
    # would make the node order (and therefore the source/target indices in
    # the returned frames) vary between kernel sessions.
    all_labels = [
        *pos_sankey_df['Label1'], *pos_sankey_df['Label2'],
        *neg_sankey_df['Label1'], *neg_sankey_df['Label2'],
    ]
    unique_actions = list(dict.fromkeys(all_labels))
    dmap = {action: position for position, action in enumerate(unique_actions)}

    pos_sankey_df = utils.create_source_and_target(pos_sankey_df, dmap)
    neg_sankey_df = utils.create_source_and_target(neg_sankey_df, dmap)

    # Nodes whose label contains 'anemia' or 'ACD' get the highlight colour and
    # the extra hover text; every other node is drawn in orange.
    # NOTE(review): create_sankey_df abbreviates 'Iron deficiency anemia' to
    # 'IDA', which matches neither substring, so that node is drawn in orange
    # -- confirm whether that is intended.
    is_highlighted = [('anemia' in node) or ('ACD' in node) for node in unique_actions]
    nodes_color = ['darkorchid' if flag else 'orange' for flag in is_highlighted]
    # NOTE(review): the precision/recall/f1 figures below are hard-coded
    # literals, not values computed from pos_df/neg_df.
    node_customdata = ['<br>precision: 0.92<br>recall: 1.00<br>f1:0.96' if flag else ''
                       for flag in is_highlighted]
    label = unique_actions

    # Positive links first, then negative links; link_color relies on this order.
    source = list(pos_sankey_df['source']) + list(neg_sankey_df['source'])
    target = list(pos_sankey_df['target']) + list(neg_sankey_df['target'])
    value = list(pos_sankey_df['value']) + list(neg_sankey_df['value'])
    link_color = [pos_color]*len(pos_sankey_df) + [neg_color]*len(neg_sankey_df)

    # NOTE(review): the distribution statistics in the link hover text are
    # hard-coded literals as well.
    link_hovertemplate = (
        '<b>%{source.label} to %{target.label}</b><br>value %{value}<br>%{source.label} distribution'
        '<br>mean:90 &#177; 2.100<br>25th percentile: 25<br>50th percentile: 50<br>75th percentile: 75<extra></extra>'
    )

    fig = go.Figure(data=[go.Sankey(
        node=dict(pad=15, thickness=20, line=dict(color='black', width=0.5), label=label, color=nodes_color,
                  hovertemplate='<b>%{label} (%{value})</b> %{customdata}<extra></extra>',
                  customdata=node_customdata
                  ),
        link=dict(source=source, target=target, value=value, color=link_color,
                  customdata=[str(i) for i in range(len(source))],
                  hovertemplate=link_hovertemplate
                  )
    )])
    fig.update_layout(title_text=title,
                      title_x=0.5,
                      title_font_size=24,
                      title_font_color='black',
                      title_font_family='Times New Roman',
                      font=dict(family='Times New Roman', size=20),
                      paper_bgcolor='rgba(0, 0, 0, 0)',
                      plot_bgcolor='rgba(0, 0, 0, 0)'
                      )

    if save:
        fig.write_html(f'{filename}.html')
    fig.show()
    return pos_sankey_df, neg_sankey_df

#### Important data

In [4]:
# Load the per-episode test results; the preview below shows one row per
# episode with its predicted class (y_pred), true class (y_actual) and the
# trajectory of actions taken.
test_df = pd.read_csv('../../../anemia_ml4hc/test_dfs/dueling_dqn_pr_test_df_seed_63_10000000.csv')
test_df.head()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success
0,0.0,5.0,1.0,5.0,5.0,"['hemoglobin', 'gender', 'rbc', 'ret_count', '...",0.0,1.0
1,1.0,6.0,1.0,1.0,1.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",0.0,1.0
2,2.0,8.0,1.0,4.0,4.0,"['hemoglobin', 'rbc', 'mcv', 'ferritin', 'hema...",0.0,1.0
3,3.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0
4,4.0,6.0,-1.0,7.0,7.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",1.0,1.0


In [5]:
# Score the predictions against the true labels. utils.test returns a 3-tuple
# of numbers (see output); which metric each position holds is not visible
# from this notebook -- TODO confirm against utils.test.
utils.test(test_df.y_actual, test_df.y_pred)

(97.18571428571428, 97.1643780144586, 98.38007040050316)

In [6]:
# Load the feature table of the test set (lab values plus a `label` column).
# NOTE(review): many cells hold -1.0, which looks like a missing-value
# sentinel -- confirm before computing statistics on these columns.
testing_df = pd.read_csv('../../final/data/test_set_constant.csv')
testing_df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,7.116363,-1.0,3.781573,2.738413,-1.0,95.904198,68.457895,2.226085,0,1.892912,39.80855,110.329197,64.40435,21.654404,73.787009,21.349089,-1.0,5
1,8.12532,92.230003,4.231419,1.188039,143.365567,104.057204,204.747831,2.342554,0,0.652614,13.478089,-1.0,32.705481,-1.0,43.520272,24.375961,142.815207,1
2,11.30945,38.324563,-1.0,-1.0,455.077909,76.402602,-1.0,4.440732,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,33.92835,-1.0,4
3,13.763858,253.513394,2.262606,0.551444,453.772884,82.781943,90.101466,4.987993,0,0.853521,104.005514,34.639227,0.963866,22.083012,88.891838,41.291574,19.856071,0
4,11.464002,-1.0,-1.0,-1.0,320.964653,104.287127,-1.0,3.297819,0,1.163516,121.616315,105.895897,-1.0,9.337462,-1.0,34.392007,-1.0,7


In [7]:
# Put the episode results and the feature table side by side. axis=1 aligns
# the two frames on their index, so this assumes row i of test_df and row i of
# testing_df describe the same sample -- TODO confirm (in the preview,
# y_actual and label agree for the first five rows).
overall_test_df = pd.concat([test_df, testing_df], axis=1)
overall_test_df.head()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success,hemoglobin,ferritin,...,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,0.0,5.0,1.0,5.0,5.0,"['hemoglobin', 'gender', 'rbc', 'ret_count', '...",0.0,1.0,7.116363,-1.0,...,0,1.892912,39.80855,110.329197,64.40435,21.654404,73.787009,21.349089,-1.0,5
1,1.0,6.0,1.0,1.0,1.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",0.0,1.0,8.12532,92.230003,...,0,0.652614,13.478089,-1.0,32.705481,-1.0,43.520272,24.375961,142.815207,1
2,2.0,8.0,1.0,4.0,4.0,"['hemoglobin', 'rbc', 'mcv', 'ferritin', 'hema...",0.0,1.0,11.30945,38.324563,...,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,33.92835,-1.0,4
3,3.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0,13.763858,253.513394,...,0,0.853521,104.005514,34.639227,0.963866,22.083012,88.891838,41.291574,19.856071,0
4,4.0,6.0,-1.0,7.0,7.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",1.0,1.0,11.464002,-1.0,...,0,1.163516,121.616315,105.895897,-1.0,9.337462,-1.0,34.392007,-1.0,7


In [8]:
# Sanity check on the combined frame: (rows, columns).
overall_test_df.shape

(14000, 26)

In [9]:
# Link colours read as globals inside draw_sankey_diagram: pos_color is applied
# to links built from pos_df, neg_color to links built from neg_df. This cell
# must run before draw_sankey_diagram is called.
pos_color = 'LimeGreen'
neg_color = 'Red'

In [None]:
def create_sankey_df(df):
    """Build the link table used to draw a Sankey diagram.

    ``utils.generate_tuple_dict(df)`` is expected to return a mapping whose
    keys are 2-tuples and whose values are the link weights; each key becomes
    one row of the result.

    Parameters
    ----------
    df : pandas.DataFrame
        Episode table accepted by ``utils.generate_tuple_dict``.

    Returns
    -------
    pandas.DataFrame
        Columns ``Label1`` (first tuple item), ``Label2`` (second tuple item),
        ``value`` (the mapped value) and ``link string`` (the key rendered by
        ``tup_to_string``). In both label columns the long names
        'Anemia of chronic disease' and 'Iron deficiency anemia' are replaced
        by 'ACD' and 'IDA'. The frame is empty when the mapping is empty.

    Notes
    -----
    ``tup_to_string`` is defined in a separate cell of this notebook and must
    have been run before this function is called.

    The previous version looked up the first ``Label2`` containing 'anemia',
    'Anemia' or 'Inconclusive' with ``[...][0]`` without ever using the
    result; that raised ``IndexError`` whenever no label matched (for
    instance when the mapping was empty), so the lookup has been removed.
    """
    overall_tup_dict = utils.generate_tuple_dict(df)
    sankey_df = pd.DataFrame({
        'Label1': [tup[0] for tup in overall_tup_dict],
        'Label2': [tup[1] for tup in overall_tup_dict],
        'value': list(overall_tup_dict.values()),
        'link string': [tup_to_string(tup) for tup in overall_tup_dict],
    })

    # Shorten the two longest diagnosis names so node labels stay readable.
    abbreviations = {'Anemia of chronic disease': 'ACD', 'Iron deficiency anemia': 'IDA'}
    for column in ('Label1', 'Label2'):
        sankey_df[column] = sankey_df[column].replace(abbreviations)
    return sankey_df

#### No anemia

In [11]:
# Keep only the episodes predicted as class 0 (the "No anemia" section), then
# split them into correct predictions (pos_df) and incorrect ones (neg_df).
class_label = 0
anem_test_df = overall_test_df[overall_test_df.y_pred==class_label]
pos_df = anem_test_df[anem_test_df.y_actual == anem_test_df.y_pred]
neg_df = anem_test_df[anem_test_df.y_actual != anem_test_df.y_pred]
pos_df.head()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success,hemoglobin,ferritin,...,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
3,3.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0,13.763858,253.513394,...,0,0.853521,104.005514,34.639227,0.963866,22.083012,88.891838,41.291574,19.856071,0
9,9.0,3.0,1.0,0.0,0.0,"['hemoglobin', 'gender', 'No anemia']",0.0,1.0,12.490978,341.360589,...,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,37.472935,-1.0,0
17,17.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0,13.89078,150.582769,...,0,0.838111,133.588356,-1.0,-1.0,-1.0,-1.0,41.672339,-1.0,0
21,21.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0,13.8574,32.918871,...,0,1.567181,46.566975,-1.0,-1.0,-1.0,-1.0,41.5722,-1.0,0
32,32.0,3.0,1.0,0.0,0.0,"['hemoglobin', 'gender', 'No anemia']",0.0,1.0,14.749014,97.856707,...,1,1.855077,0.794619,-1.0,-1.0,-1.0,-1.0,44.247043,-1.0,0


#### delete from here

In [12]:
# Scratch inspection of the helper's output for pos_df: per the output below it
# is a dict keyed by 2-tuples of step names with integer values.
overall_tup_dict = utils.generate_tuple_dict(pos_df)
overall_tup_dict

{('hemoglobin', 'No anemia'): 337,
 ('hemoglobin', 'gender'): 1052,
 ('gender', 'No anemia'): 1663,
 ('hemoglobin', 'hematocrit'): 611,
 ('hematocrit', 'gender'): 611}

In [15]:
def tup_to_string(tup):
    """Render the items of ``tup`` joined by ``', '`` (quote, comma, space, quote).

    The result is what you get by single-quoting every item, joining the
    quoted items with ", " and then dropping the very first and very last
    quote character, e.g. ``('a', 'b')`` -> ``a', 'b``. An empty ``tup``
    yields an empty string.
    """
    # Joining the bare items with "', '" yields exactly the quoted-and-joined
    # text minus its outermost pair of quotes.
    separator = "', '"
    return separator.join(format(item) for item in tup)

In [16]:
# Scratch version of the table built inside create_sankey_df (without the
# 'value' column and the label abbreviations), used to eyeball what
# tup_to_string produces for each key of overall_tup_dict.
sankey_df = pd.DataFrame()
sankey_df['Label1'] = [i[0] for i in overall_tup_dict.keys()]
sankey_df['Label2'] = [i[1] for i in overall_tup_dict.keys()]
sankey_df['link string'] = [tup_to_string(tup) for tup in overall_tup_dict.keys()]
sankey_df.head()

Unnamed: 0,Label1,Label2,link string
0,hemoglobin,No anemia,"hemoglobin', 'No anemia"
1,hemoglobin,gender,"hemoglobin', 'gender"
2,gender,No anemia,"gender', 'No anemia"
3,hemoglobin,hematocrit,"hemoglobin', 'hematocrit"
4,hematocrit,gender,"hematocrit', 'gender"


#### end here

In [None]:
# Draw the diagram for the class-0 selection using the notebook-local
# draw_sankey_diagram (save defaults to False, so nothing is written to disk)
# and keep the two returned link tables for inspection.
pos_sankey_df, neg_sankey_df = draw_sankey_diagram(pos_df, neg_df, 'No anemia pathways')

In [None]:
# Display the link table returned for pos_df (the correctly predicted rows).
pos_sankey_df

#### delete from here

In [None]:
# fig = go.Figure(data=[go.Sankey(
#     node = dict(pad=15, 
#                 thickness=20, 
#                 line=dict(color='black', width=0.5), 
#                 label=label, 
#                 color=nodes_color, 
#                 hovertemplate='<b>%{label} (%{value})</b> %{customdata}<extra></extra>',
#         customdata=['<br>precision: 0.92<br>recall: 1.00<br>f1:0.96' if (('anemia' in node)|('ACD' in node)) else '' for node in label]),
#     link= dict(source=source, 
#                target=target, 
#                value=value, 
#                color=link_color1, 
#                customdata = [str(i) for i in range(len(source))],
#                hovertemplate='<b>%{source.label} to %{target.label}</b><br />value %{value}<br>%{source.label} distribution'+
#                '<br>mean:90 &#177; 2.100<br>25th percentile: 25<br>50th percentile: 50<br>75th percentile: 75<extra></extra>'
        
#               )
# )])
# fig.update_layout(title_text=title, 
#                   title_x=0.5,  
#                   title_font_size=24, 
#                   title_font_color='black', 
#                   title_font_family='Times New Roman', 
#                   font = dict(family='Times New Roman', size=18, color='black'),
#                   paper_bgcolor='rgba(0, 0, 0, 0)',
#                   plot_bgcolor='rgba(0, 0, 0, 0)'
#                   )

#### Getting the data distribution of the source

#### end here

In [None]:
# utils.plot_classification_report(test_df['y_actual'], test_df['y_pred'])

In [None]:
# utils.plot_confusion_matrix(test_df['y_actual'], test_df['y_pred'])

In [None]:
# create_sankey(test_df, 'Overall pathways', True, 
#               filename = '../../../anemia_ml4hc/pathways/seed_0/test_df_dueling_dqn_pr_seed_147_basic_10000000')

#### Success df

In [None]:
# success_df = pd.read_csv('../../final/test_dfs/dqn_success_mcv_rbc_50_9000000.csv')
# create_sankey(success_df, 'Pathways of successful episodes', save=False, 
#               filename='../../final/pathways/success_df_mcv_rbc_50_9000000')

In [None]:
# utils.draw_sankey_diagram(success_df, 'Pathways of successful episodes', save=True, 
#                          filename='../../pathways/many_features/0.1/correlated/tsuccess_df3_noisy6_230000000')

In [None]:
# for i in range(constants.CLASS_NUM):
#     print(utils.anemias[i])
#     anemia_df = test_df[test_df.y_pred==i]
#     if len(anemia_df!=0):
#         utils.draw_sankey_diagram(anemia_df, utils.generate_title(i, len(anemia_df)), save=True, 
#                                   filename=f'../../pathways/many_features/0.1/correlated/{utils.generate_filename(i)}_noisy6_23000000')

In [None]:
#precision - shows true positives and false positives
# Precision view: for every class, take the episodes PREDICTED as that class,
# so each diagram shows its true positives and false positives.
# NOTE(review): test_df is loaded in this notebook from a file named
# ..._seed_63_..., while the outputs below are written under seed_147 --
# confirm the seed in these paths.
for i in range(constants.CLASS_NUM):
    print(utils.anemias[i])
    anemia_df = test_df[test_df.y_pred==i]
    # Skip classes that were never predicted. (The parenthesis used to be
    # misplaced as len(anemia_df != 0), which built a whole boolean frame just
    # to take its length.)
    if len(anemia_df) != 0:
        create_sankey(anemia_df, utils.generate_title(i, len(anemia_df)), save=True, 
        filename=f'../../../anemia_ml4hc/pathways/seed_147/{utils.generate_filename(i)}_dueling_dqn_pr_basic_seed_147_10000000')

In [None]:
# recall - shows true positives and false negatives
# Recall view: for every class, take the episodes whose TRUE label is that
# class, so each diagram shows its true positives and false negatives.
# NOTE(review): test_df is loaded in this notebook from a file named
# ..._seed_63_..., while the outputs below are written under seed_147 --
# confirm the seed in these paths.
for i in range(constants.CLASS_NUM):
    anemia_df = test_df[test_df.y_actual == i]
    # Skip classes with no test episodes. (The parenthesis used to be
    # misplaced as len(anemia_df != 0), which built a whole boolean frame just
    # to take its length.)
    if len(anemia_df) != 0:
        create_sankey(anemia_df, utils.generate_title(i, len(anemia_df)), save=True, 
        filename=f'../../../anemia_ml4hc/pathways/seed_147/recall/{utils.generate_filename(i)}_dueling_dqn_pr_basic_seed_147_10000000')
        

#### For mlhc - to delete

In [None]:
# Episodes whose true class is 5 but which were predicted as class 1 or 3
# (presumably class 5 is hemolytic anemia, going by the title -- TODO confirm).
# Every selected row has y_actual != y_pred, so the "correct" frame that
# create_sankey builds from it is empty.
misdiag_df = test_df[(test_df.y_actual==5) & (test_df.y_pred.isin([1, 3]))]
create_sankey(misdiag_df, 'misdiagnosed hemolytic episodes', save=False, filename='')