In [1]:
import pandas as pd
import numpy as np
import ast
import sys
sys.path.append('../..')
from modules.many_features import utils
import plotly.graph_objects as go
import matplotlib.pyplot as plt
%matplotlib inline

#### The constant test dataset

In [2]:
# Load the fixed ("constant") test set: one row per sample with the lab-feature
# columns and the ground-truth class in `label`.
# NOTE(review): -1.0 looks like a placeholder for a feature value that is not
# available for the sample — confirm against the data-generation code.
X_test_df = pd.read_csv('../../../anemia_ml4hc/data/test_set_constant.csv')
X_test_df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,7.116363,-1.0,3.781573,2.738413,-1.0,95.904198,68.457895,2.226085,0,1.892912,39.80855,110.329197,64.40435,21.654404,73.787009,21.349089,-1.0,5
1,8.12532,92.230003,4.231419,1.188039,143.365567,104.057204,204.747831,2.342554,0,0.652614,13.478089,-1.0,32.705481,-1.0,43.520272,24.375961,142.815207,1
2,11.30945,38.324563,-1.0,-1.0,455.077909,76.402602,-1.0,4.440732,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,33.92835,-1.0,4
3,13.763858,253.513394,2.262606,0.551444,453.772884,82.781943,90.101466,4.987993,0,0.853521,104.005514,34.639227,0.963866,22.083012,88.891838,41.291574,19.856071,0
4,11.464002,-1.0,-1.0,-1.0,320.964653,104.287127,-1.0,3.297819,0,1.163516,121.616315,105.895897,-1.0,9.337462,-1.0,34.392007,-1.0,7


#### The DQN test df

In [3]:
# Load the per-episode evaluation results of the agent: predicted and actual
# class, episode length, reward, and `trajectory` (the list of actions taken,
# stored as its string representation).
# NOTE(review): the file name suggests a dueling DQN run with seed 63 and
# 10000000 training steps — confirm.
test_df = pd.read_csv('../../../anemia_ml4hc/test_dfs/dueling_dqn_pr_test_df_seed_63_10000000.csv')
test_df.head()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success
0,0.0,5.0,1.0,5.0,5.0,"['hemoglobin', 'gender', 'rbc', 'ret_count', '...",0.0,1.0
1,1.0,6.0,1.0,1.0,1.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",0.0,1.0
2,2.0,8.0,1.0,4.0,4.0,"['hemoglobin', 'rbc', 'mcv', 'ferritin', 'hema...",0.0,1.0
3,3.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0
4,4.0,6.0,-1.0,7.0,7.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",1.0,1.0


In [4]:
# Score the predictions against the ground truth. `utils.test` is project-local;
# it returns a 3-tuple of metrics (which metrics, and in what order, is not
# visible from this notebook — check modules.many_features.utils).
utils.test(test_df.y_actual, test_df.y_pred)

(97.18571428571428, 97.1643780144586, 98.38007040050316)

In [5]:
# Sanity check: the results frame and the feature frame must describe the same
# samples in the same order, otherwise combining them column-wise is meaningless.
# Both label vectors are cast to float32 so that the float `y_actual` and the
# integer `label` compare equal.
a1 = np.array(test_df['y_actual'], dtype=np.float32)
a2 = np.array(X_test_df['label'], dtype=np.float32)
assert np.array_equal(a1, a2)

In [6]:
# Put the episode results and the raw feature values side by side (axis=1 joins
# on the row index), so every trajectory can be related to the feature values
# of the sample it was run on.
combined_full_test_df = pd.concat([test_df, X_test_df], axis=1)
combined_full_test_df.head()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success,hemoglobin,ferritin,...,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,0.0,5.0,1.0,5.0,5.0,"['hemoglobin', 'gender', 'rbc', 'ret_count', '...",0.0,1.0,7.116363,-1.0,...,0,1.892912,39.80855,110.329197,64.40435,21.654404,73.787009,21.349089,-1.0,5
1,1.0,6.0,1.0,1.0,1.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",0.0,1.0,8.12532,92.230003,...,0,0.652614,13.478089,-1.0,32.705481,-1.0,43.520272,24.375961,142.815207,1
2,2.0,8.0,1.0,4.0,4.0,"['hemoglobin', 'rbc', 'mcv', 'ferritin', 'hema...",0.0,1.0,11.30945,38.324563,...,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,33.92835,-1.0,4
3,3.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0,13.763858,253.513394,...,0,0.853521,104.005514,34.639227,0.963866,22.083012,88.891838,41.291574,19.856071,0
4,4.0,6.0,-1.0,7.0,7.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",1.0,1.0,11.464002,-1.0,...,0,1.163516,121.616315,105.895897,-1.0,9.337462,-1.0,34.392007,-1.0,7


#### Some useful functions

In [13]:
def generate_tuple_dict(df):
    """Count how often each consecutive pair of actions occurs across trajectories.

    Parameters
    ----------
    df : object with a ``trajectory`` attribute (e.g. a DataFrame column)
        Each trajectory is the string representation of a sequence of actions,
        e.g. ``"['hemoglobin', 'gender', 'No anemia']"``.

    Returns
    -------
    dict
        Maps ``(action, next_action)`` tuples to the number of trajectories
        (counting duplicates) that contain that step.
    """
    # First pass: collapse identical trajectory strings so that each distinct
    # string only has to be parsed once.
    trajectory_counts = {}
    for raw_trajectory in df.trajectory:
        trajectory_counts[raw_trajectory] = trajectory_counts.get(raw_trajectory, 0) + 1

    # Second pass: parse each distinct trajectory and credit every adjacent
    # pair of actions with the number of times that trajectory occurred.
    transition_counts = {}
    for raw_trajectory, occurrences in trajectory_counts.items():
        actions = ast.literal_eval(raw_trajectory)
        for step in range(len(actions) - 1):
            transition = (actions[step], actions[step + 1])
            transition_counts[transition] = transition_counts.get(transition, 0) + occurrences
    return transition_counts

In [14]:
def get_threshold_value(row, df=None):
    """Return the (min, max) of feature ``row.Label1`` over the episodes that
    took the step ``row.Label1 -> row.Label2``.

    Parameters
    ----------
    row : object with ``Label1`` and ``Label2`` attributes
        One link of the Sankey frame; ``Label1`` must also be a column of ``df``.
    df : pandas.DataFrame, optional
        Frame with a string ``trajectory`` column and the feature columns.
        Defaults to the notebook-level ``combined_full_test_df``.

    Returns
    -------
    tuple
        ``(min_value, max_value)``, both rounded to 2 decimals, or
        ``(nan, nan)`` when no trajectory contains the step.

    Raises
    ------
    KeyError
        If ``row.Label1`` is not a column of ``df``.
    """
    if df is None:
        # Keep the original behaviour for existing one-argument callers.
        df = combined_full_test_df

    # Trajectories are stored as the repr of a list, so two consecutive actions
    # show up in the string as "'a', 'b'".
    substring = f"'{row.Label1}', '{row.Label2}'"

    # Vectorised literal substring match. This replaces the row-by-row
    # DataFrame.append loop, which was quadratic and no longer exists in
    # pandas >= 2.0.
    matches = df['trajectory'].str.contains(substring, regex=False, na=False)
    feature_values = df.loc[matches, row.Label1]

    if feature_values.empty:
        # Previously this case failed with a KeyError on an empty frame.
        return np.nan, np.nan

    # Both bounds use the same precision (the max used to be rounded to an
    # integer, which discarded information).
    return round(float(feature_values.min()), 2), round(float(feature_values.max()), 2)

In [15]:
def create_sankey_df(df):
    """Build the link table for a Sankey diagram from a frame of trajectories.

    Parameters
    ----------
    df : frame accepted by ``generate_tuple_dict`` (needs a ``trajectory`` column)

    Returns
    -------
    pandas.DataFrame
        One row per distinct transition with columns ``Label1`` (action),
        ``Label2`` (next action) and ``value`` (how often the step occurred).
    """
    transition_counts = generate_tuple_dict(df)
    transitions = list(transition_counts)

    sankey_df = pd.DataFrame()
    sankey_df['Label1'] = [first for first, _ in transitions]
    sankey_df['Label2'] = [second for _, second in transitions]
    sankey_df['value'] = [transition_counts[pair] for pair in transitions]
    # Optional per-link feature range, currently disabled:
    #     sankey_df['thresholds'] = sankey_df.apply(lambda row: get_threshold_value(row), axis=1)
    return sankey_df

In [16]:
def create_source_and_target(sankey_df, dmap):
    """Attach integer node ids to every link and order the links by source id.

    Parameters
    ----------
    sankey_df : pandas.DataFrame
        Link table with ``Label1`` and ``Label2`` columns. It is modified in
        place (two columns added, rows re-sorted).
    dmap : dict
        Maps an action label to its node index.

    Returns
    -------
    pandas.DataFrame
        The same ``sankey_df`` object, now with ``source`` and ``target``
        columns, sorted by ``source``.
    """
    for id_column, label_column in (('source', 'Label1'), ('target', 'Label2')):
        sankey_df[id_column] = sankey_df[label_column].map(dmap)
    sankey_df.sort_values(by=['source'], inplace=True)
    return sankey_df

In [35]:
def draw_sankey_diagram(pos_df, neg_df, title, save=False, filename=False):
    """Draw one Sankey diagram of action-to-action transitions for two groups of episodes.

    Links derived from ``pos_df`` are coloured yellow and links derived from
    ``neg_df`` blue; all nodes are orange. The figure is shown and, optionally,
    written to ``<filename>.html``.

    Parameters
    ----------
    pos_df, neg_df : pandas.DataFrame
        Episode frames with a ``trajectory`` column (see ``create_sankey_df``).
    title : str
        Figure title.
    save : bool, optional
        When true, also export the figure as HTML.
    filename : str, optional
        Output path without the ``.html`` extension; required when ``save`` is true.

    Raises
    ------
    ValueError
        If ``save`` is true but no file name was given (this used to silently
        write ``False.html`` or ``.html``).
    """
    if save and not filename:
        raise ValueError('draw_sankey_diagram: a filename is required when save=True')

    pos_sankey_df = create_sankey_df(pos_df)
    neg_sankey_df = create_sankey_df(neg_df)

    # De-duplicate the action labels while keeping first-seen order.
    # Unlike set(), this gives the same node indices on every kernel restart,
    # so the rendered figure is reproducible.
    unique_actions = list(dict.fromkeys(
        action
        for frame in (pos_sankey_df, neg_sankey_df)
        for column in ('Label1', 'Label2')
        for action in frame[column]
    ))
    dmap = {action: position for position, action in enumerate(unique_actions)}

    pos_sankey_df = create_source_and_target(pos_sankey_df, dmap)
    neg_sankey_df = create_source_and_target(neg_sankey_df, dmap)

    nodes_color = 'orange'
    label = unique_actions

    # Positive links first, then negative links; link_color must follow the same order.
    source = list(pos_sankey_df['source']) + list(neg_sankey_df['source'])
    target = list(pos_sankey_df['target']) + list(neg_sankey_df['target'])
    value = list(pos_sankey_df['value']) + list(neg_sankey_df['value'])
    link_color = ['yellow'] * len(pos_sankey_df) + ['blue'] * len(neg_sankey_df)

    fig = go.Figure(data=[go.Sankey(
        node=dict(pad=15, thickness=20, line=dict(color='black', width=0.5), label=label, color=nodes_color),
        link=dict(source=source, target=target, value=value, color=link_color)
    )])
    fig.update_layout(title_text=title,
                      title_x=0.5,
                      title_font_size=24,
                      title_font_color='black',
                      title_font_family='Times New Roman',
                      font=dict(family='Times New Roman', size=18),
                      paper_bgcolor='rgba(0, 0, 0, 0)',
                      plot_bgcolor='rgba(0, 0, 0, 0)'
                      )

    if save:
        fig.write_html(f'{filename}.html')
    fig.show()

In [36]:
def create_sankey(df, title, save, filename, pos_label=1, neg_label=3):
    """Draw a Sankey diagram comparing the episodes predicted as two different classes.

    Parameters
    ----------
    df : pandas.DataFrame
        Episode frame with ``y_pred`` and ``trajectory`` columns.
    title : str
        Figure title.
    save : bool
        Whether to export the figure as HTML.
    filename : str
        Output path without extension (used only when ``save`` is true).
    pos_label, neg_label : optional
        Predicted classes shown as the first and second link group.
        The defaults (1 and 3) are the values that were hard-coded before.
    """
    # The original selected on `df.y_`, a truncated attribute name that raises
    # AttributeError; the predicted class is what distinguishes the two groups.
    # Alternative split kept from the original notes: correct vs. incorrect
    # predictions, i.e. df.y_actual == df.y_pred / df.y_actual != df.y_pred.
    pos_df = df[df.y_pred == pos_label]
    neg_df = df[df.y_pred == neg_label]
    draw_sankey_diagram(pos_df, neg_df, title, save, filename)

In [37]:
# Episodes whose actual class is 5 but that were predicted as class 1 or 3,
# visualised as a Sankey diagram of the actions taken.
misdiag_hem_df = test_df[(test_df.y_actual==5) & (test_df.y_pred.isin([1, 3]))]
create_sankey(misdiag_hem_df, 'misdiagnosed hemolytic episodes', save=False, filename='')

In [41]:
# Episodes predicted as class 0 although the actual class is different.
misdiag_no_df = test_df[(test_df.y_pred == 0) & (test_df.y_actual!=0)]
# Here the two link groups are split by the *actual* class (1 vs. 5), not by
# prediction correctness, despite the pos/neg names.
pos_df = misdiag_no_df[misdiag_no_df.y_actual==1]
neg_df = misdiag_no_df[misdiag_no_df.y_actual==5]
draw_sankey_diagram(pos_df, neg_df, 'Misdiagnosed as no anemia', save=False, filename='')

In [43]:
# Summary statistics of the numeric columns for the episodes wrongly predicted as class 0.
misdiag_no_df.describe()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,terminated,is_success
count,44.0,44.0,44.0,44.0,44.0,44.0,44.0
mean,6786.931818,3.0,-1.0,0.0,3.454545,0.0,0.0
std,3984.587157,0.0,0.0,0.0,1.970179,0.0,0.0
min,7.0,3.0,-1.0,0.0,1.0,0.0,0.0
25%,3561.5,3.0,-1.0,0.0,1.0,0.0,0.0
50%,7189.0,3.0,-1.0,0.0,5.0,0.0,0.0
75%,9933.5,3.0,-1.0,0.0,5.0,0.0,0.0
max,13980.0,3.0,-1.0,0.0,5.0,0.0,0.0


#### Drawing a Sankey diagram for episodes predicted as no anemia - correct vs. misdiagnosed

In [11]:
# Select every episode predicted as class 0 and split it into correct
# predictions (pos) and incorrect predictions (neg).
# NOTE(review): the `vit_` prefix does not match the filter (y_pred == 0);
# presumably a leftover from an earlier analysis — confirm before renaming.
vit_df = combined_full_test_df[combined_full_test_df.y_pred ==0]
vit_pos_df = vit_df[vit_df.y_actual == vit_df.y_pred]
vit_neg_df = vit_df[vit_df.y_actual != vit_df.y_pred]
len(vit_pos_df), len(vit_neg_df)

(2000, 164)

In [12]:
# Build the link tables for the correctly and incorrectly predicted episodes.
# TODO: each link should also carry a column with the min and max feature
# value; compute it inside create_sankey_df or even earlier.
pos_sankey_df = create_sankey_df(vit_pos_df)
neg_sankey_df = create_sankey_df(vit_neg_df)
pos_sankey_df.head()

Unnamed: 0,Label1,Label2,value,thresholds
0,hemoglobin,No anemia,337,"(13.13, 14)"
1,hemoglobin,gender,1052,"(6.11, 16)"
2,gender,No anemia,1663,"(0.0, 1)"
3,hemoglobin,hematocrit,611,"(15.6, 17)"
4,hematocrit,gender,611,"(18.17, 52)"


In [None]:
# Display the full link table of the correctly predicted episodes.
pos_sankey_df

In [None]:
# Display the full link table of the incorrectly predicted episodes.
neg_sankey_df

In [None]:
# Collect every action label that occurs at either end of a link in either
# table, and give each one an integer node index.
# Note: the order comes from a set, so the label -> index assignment can
# differ between kernel sessions.
unique_actions = list(set(list(pos_sankey_df['Label1'].unique()) + list(pos_sankey_df['Label2'].unique()) + list(neg_sankey_df['Label1'].unique()) + list(neg_sankey_df['Label2'].unique())))
dmap = dict(zip(unique_actions, range(len(unique_actions))))
dmap

In [None]:
# Translate the labels of every link into node indices (adds `source` and
# `target` columns and sorts each table by `source`).
pos_sankey_df = create_source_and_target(pos_sankey_df, dmap)
neg_sankey_df = create_source_and_target(neg_sankey_df, dmap)
#nodes_color = get_colors(len(dmap))
# A single colour is used for all nodes.
nodes_color = 'orange'

# Node labels, in the same order as the indices stored in dmap.
label = unique_actions
pos_sankey_df.head()

In [None]:
# Flatten both link tables into the parallel lists plotly expects: all
# positive links first, then all negative links. link_color follows the same
# order (green = correctly predicted, red = incorrectly predicted).
target = list(pos_sankey_df['target']) + list(neg_sankey_df['target'])
value = list(pos_sankey_df['value']) + list(neg_sankey_df['value'])
source = list(pos_sankey_df['source']) + list(neg_sankey_df['source'])
link_color = ['green']*len(pos_sankey_df) + ['red']*len(neg_sankey_df)
target

In [None]:
# Quick consistency check: target, value and source hold one entry per link
# and must have equal lengths; label holds one entry per node.
len(label), len(target), len(value), len(source)

In [None]:
# Assemble the Sankey figure from the node labels and the parallel link lists
# built in the previous cells. (A stray "9" before the closing brackets used to
# make this cell a SyntaxError.)
fig = go.Figure(data=[go.Sankey(
        node = dict(pad=15, thickness=20, line=dict(color='black', width=0.5), label=label, color=nodes_color),
        link= dict(source=source, target=target, value=value, color=link_color)
    )])

In [None]:
# Render the assembled figure as the cell output.
fig

#### Drawing a Sankey diagram for misdiagnosed episodes whose actual label is hemolytic anemia

In [None]:
# Earlier, broader selection (actual class 5 predicted as 0 or 1), kept for reference:
# misdiag_as_hem_anem =combined_full_test_df[(combined_full_test_df.y_actual==5) & (combined_full_test_df.y_pred.isin([0, 1]))]
# Current selection: actual class 5 that was predicted as class 1.
misdiag_hem_samples = combined_full_test_df[(combined_full_test_df.y_pred==1) & (combined_full_test_df.y_actual==5)] 
len(misdiag_hem_samples)

In [None]:
# Plot the trajectories of these samples with the project-local helper
# (its implementation is in modules.many_features.utils, not in this notebook).
utils.draw_sankey_diagram_orig(misdiag_hem_samples, 'Misdiagnosed hemolytic anemia samples')

In [None]:
# Summary statistics of the hemoglobin values of the misdiagnosed samples.
misdiag_hem_samples.describe()['hemoglobin']

#### Drawing a sankey diagram for those misdiagnosed with NO anemia

In [None]:
# Samples predicted as class 0 whose actual class is something else.
misdiag_as_no_anem = combined_full_test_df[(combined_full_test_df.y_pred==0) & (combined_full_test_df.y_actual!=0)]
len(misdiag_as_no_anem)

In [None]:
# Plot the trajectories of these samples with the project-local helper
# (its implementation is in modules.many_features.utils, not in this notebook).
utils.draw_sankey_diagram_orig(misdiag_as_no_anem, 'Misdiagnosed as no anemia')