In [62]:
import pandas as pd
import numpy as np
import random
import sys
sys.path.append('../../..')
from modules.many_features import utils, lupus_constants
import plotly.graph_objects as go
import matplotlib.pyplot as plt
%matplotlib inline

#### Methods to use

In [63]:
def generate_filename(i):
    """Build a file-name-friendly slug for the class at position ``i``.

    The class name is the ``i``-th key of ``lupus_constants.CLASS_DICT`` (in the
    order the keys are listed).  It is lower-cased, and every space and every
    ``/`` is replaced by an underscore.

    Parameters
    ----------
    i : int
        Position of the class among the keys of ``lupus_constants.CLASS_DICT``.

    Returns
    -------
    str
        The normalised class name.
    """
    class_name = list(lupus_constants.CLASS_DICT.keys())[i]
    slug = class_name.lower()
    # Characters that are awkward inside a file name are flattened to '_'.
    for awkward_char in (' ', '/'):
        slug = slug.replace(awkward_char, '_')
    return slug

In [64]:
def get_colors(n):
    """Return ``n`` random colours as ``#rrggbb`` hexadecimal strings.

    Each colour is drawn independently with ``random.randint(0, 0xFFFFFF)`` and
    zero-padded to six hex digits, so the result depends on the state of the
    ``random`` module's global generator.
    """
    return ["#" + "%06x" % random.randint(0, 0xFFFFFF) for _ in range(n)]

In [65]:
def draw_sankey_diagram_orig(df, title, compress=False, compress_val=0, save=False, filename=False):
    """Draw (and optionally save) a Sankey diagram of the transitions in ``df``.

    ``utils.generate_tuple_dict(df)`` supplies a mapping whose keys are treated
    as ``(source label, target label)`` pairs and whose values are used as the
    link weights.  Every distinct label becomes a node; nodes and links are
    given random colours via ``get_colors``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame handed unchanged to ``utils.generate_tuple_dict``.
        NOTE(review): presumably needs the ``trajectory`` column - confirm
        against ``utils.generate_tuple_dict``.
    title : str
        Title shown above the figure.
    compress : bool, default False
        When true, only links whose weight is strictly greater than
        ``compress_val`` are drawn.
    compress_val : number, default 0
        Threshold used when ``compress`` is true.
    save : bool, default False
        When true, the figure is written to ``f'{filename}.html'``.
    filename : str, default False
        Output path *without* the ``.html`` extension.  Required when ``save``
        is true.

    Raises
    ------
    ValueError
        If ``save`` is true but no ``filename`` was given (the old behaviour
        was to silently write a file literally named ``False.html``).

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    if save and not filename:
        raise ValueError("save=True requires a filename (the '.html' suffix is appended automatically)")

    overall_tuple_dict = utils.generate_tuple_dict(df)
    sankey_df = pd.DataFrame({
        'Label1': [pair[0] for pair in overall_tuple_dict.keys()],
        'Label2': [pair[1] for pair in overall_tuple_dict.keys()],
        'value': list(overall_tuple_dict.values()),
    })
    if compress:
        # .copy() makes the filtered rows an independent frame, so the column
        # assignments below do not raise SettingWithCopyWarning.
        sankey_df = sankey_df[sankey_df['value'] > compress_val].copy()

    # Node labels in order of first appearance (all sources, then all targets).
    # A set() would make the node indices - and with them the colour
    # assignment - depend on string hashing, which varies between kernels.
    unique_actions = pd.concat(
        [sankey_df['Label1'], sankey_df['Label2']], ignore_index=True
    ).unique().tolist()
    dmap = {action: position for position, action in enumerate(unique_actions)}

    sankey_df['source'] = sankey_df['Label1'].map(dmap)
    sankey_df['target'] = sankey_df['Label2'].map(dmap)
    # Stable sort: links that share a source keep their original relative order.
    sankey_df = sankey_df.sort_values(by=['source'], kind='mergesort')

    label = unique_actions
    source = sankey_df['source'].tolist()
    target = sankey_df['target'].tolist()
    value = sankey_df['value'].tolist()

    nodes_color = get_colors(len(dmap))
    link_color = get_colors(len(value))

    fig = go.Figure(data=[go.Sankey(
        node=dict(pad=15, thickness=20, line=dict(color='black', width=0.5), label=label, color=nodes_color),
        link=dict(source=source, target=target, value=value, color=link_color)
    )])
    fig.update_layout(title_text=title, title_x=0.5, title_font_size=24, title_font_color='black',
                      title_font_family='Times New Roman')
    if save:
        fig.write_html(f'{filename}.html')
    fig.show()

#### Test df

In [66]:
# Load the test results CSV and preview its first five rows.
test_df_path = '../../../lupus_trial/test_dfs/27_jan/dqn_test_df_modified_11000000.csv'
test_df = pd.read_csv(filepath_or_buffer=test_df_path)
test_df.head(n=5)

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success
0,0.0,5.0,1.0,0.0,0.0,"['ana', 'delirium', 'psychosis', 'seizure', 'N...",0.0,1.0
1,1.0,5.0,1.0,0.0,0.0,"['ana', 'delirium', 'psychosis', 'seizure', 'N...",0.0,1.0
2,2.0,7.0,1.0,0.0,0.0,"['ana', 'delirium', 'psychosis', 'seizure', 't...",0.0,1.0
3,3.0,3.0,1.0,1.0,1.0,"['ana', 'seizure', 'Lupus']",0.0,1.0
4,4.0,3.0,1.0,1.0,1.0,"['ana', 'seizure', 'Lupus']",0.0,1.0


In [68]:
# Draw and save the diagram for the whole test frame.  The function appends
# '.html' to `filename`, so the stem must not carry the '.csv' extension
# (it used to produce '...test_df_modified_11000000.csv.html').
draw_sankey_diagram_orig(test_df, 'Overall pathways', save=True, compress=True,
                         filename='../../../lupus_trial/pathways/27_jan/test_df_modified_11000000')

#### Success df

In [69]:
success_df = pd.read_csv('../../../lupus_trial/test_dfs/27_jan/dqn_success_df_modified_11000000.csv')
# Plot the frame loaded above: the call used to pass test_df (leaving
# success_df unused) and was missing the comma after compress=True.
# NOTE(review): the title is still 'Overall pathways', the same as the test-df
# diagram - confirm whether a success-specific title is wanted.
draw_sankey_diagram_orig(success_df, 'Overall pathways', save=True, compress=True,
                         filename='../../../lupus_trial/pathways/27_jan/success_df_modified_11000000')

#### The 2 classes

In [70]:
# Precision view: one diagram per *predicted* class, so each diagram contains
# that prediction's true positives and false positives.
for i in range(lupus_constants.CLASS_NUM):
    class_df = test_df[test_df.y_pred == i]
    # Skip classes that were never predicted.  (The old test,
    # `len(class_df != 0)`, had a misplaced parenthesis and compared every
    # cell with 0 just to count the rows.)
    if len(class_df) == 0:
        continue
    draw_sankey_diagram_orig(class_df, utils.generate_title(i, len(class_df)), save=True,
              filename=f'../../../lupus_trial/pathways/27_jan/{utils.generate_filename(i)}_simp_bin_100000')