In [1]:
import pandas as pd
import numpy as np
import random
import sys
sys.path.append('../..')
from modules.many_features import utils
import matplotlib.pyplot as plt
%matplotlib inline

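#### The data

Load the agent's per-episode test results (`test_df`) and the raw feature test set (`testing_df`). The two tables are expected to be row-aligned by index.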

In [2]:
# Per-episode test results of the RL agent: trajectory, prediction and label per row.
# NOTE(review): the file name suggests a dueling-DQN run with seed 63 and 10M steps — confirm.
test_df = pd.read_csv(
    '../../../anemia_ml4hc/test_dfs/dueling_dqn_pr_test_df_seed_63_10000000.csv'
)
test_df.head()

Unnamed: 0,index,episode_length,reward,y_pred,y_actual,trajectory,terminated,is_success
0,0.0,5.0,1.0,5.0,5.0,"['hemoglobin', 'gender', 'rbc', 'ret_count', '...",0.0,1.0
1,1.0,6.0,1.0,1.0,1.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",0.0,1.0
2,2.0,8.0,1.0,4.0,4.0,"['hemoglobin', 'rbc', 'mcv', 'ferritin', 'hema...",0.0,1.0
3,3.0,2.0,1.0,0.0,0.0,"['hemoglobin', 'No anemia']",0.0,1.0
4,4.0,6.0,-1.0,7.0,7.0,"['hemoglobin', 'gender', 'rbc', 'mcv', 'segmen...",1.0,1.0


In [3]:
# Score the agent's predictions (y_pred) against y_actual.
# NOTE(review): utils.test returns a 3-tuple (see output), presumably percentage
# metrics such as accuracy / F1 — confirm in modules.many_features.utils.
utils.test(test_df['y_actual'], test_df['y_pred'])

(97.18571428571428, 97.1643780144586, 98.38007040050316)

In [4]:
# Raw feature values of the test set plus the true `label`. Rows are expected to
# line up by index with test_df (checked in get_pathway_thresholds).
# NOTE(review): -1.0 appears to mark features that were not measured — confirm.
testing_df = pd.read_csv(
    '../../final/data/test_set_constant.csv'
)
testing_df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,7.116363,-1.0,3.781573,2.738413,-1.0,95.904198,68.457895,2.226085,0,1.892912,39.80855,110.329197,64.40435,21.654404,73.787009,21.349089,-1.0,5
1,8.12532,92.230003,4.231419,1.188039,143.365567,104.057204,204.747831,2.342554,0,0.652614,13.478089,-1.0,32.705481,-1.0,43.520272,24.375961,142.815207,1
2,11.30945,38.324563,-1.0,-1.0,455.077909,76.402602,-1.0,4.440732,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,33.92835,-1.0,4
3,13.763858,253.513394,2.262606,0.551444,453.772884,82.781943,90.101466,4.987993,0,0.853521,104.005514,34.639227,0.963866,22.083012,88.891838,41.291574,19.856071,0
4,11.464002,-1.0,-1.0,-1.0,320.964653,104.287127,-1.0,3.297819,0,1.163516,121.616315,105.895897,-1.0,9.337462,-1.0,34.392007,-1.0,7


#### Helper functions for building the Sankey link tables

In [21]:
def tup_to_string(tup):
    """Render a link tuple as the inner text of its quoted, comma-separated form.

    ``('hemoglobin', 'gender')`` becomes ``"hemoglobin', 'gender"``. The items are
    joined by ``', '`` and the outermost quote characters are omitted, so the
    result can be searched for as a substring of a stringified trajectory list.
    """
    return "', '".join(f"{item}" for item in tup)

In [22]:
def get_pathway_thresholds(row, df, reference_df=None):
    """Summarise the feature values observed on one Sankey link.

    Selects the episodes in ``df`` whose ``trajectory`` string contains
    ``row['link_string']`` (case-insensitive, literal substring match). It then
    looks up the same rows, by index, in ``reference_df``, the raw feature table.

    Parameters
    ----------
    row : pandas.Series
        One row of a Sankey link table; must provide ``'Parent'`` and ``'link_string'``.
    df : pandas.DataFrame
        Episode table with ``trajectory`` and ``y_actual`` columns.
    reference_df : pandas.DataFrame, optional
        Raw feature table indexed like ``df``, with a ``label`` column.
        Defaults to the notebook-global ``testing_df``.

    Returns
    -------
    tuple of 5
        For a numeric parent feature: ``(mean, std, min, max, median)`` of that
        feature over the matched rows (NaN when nothing matches).
        For ``'gender'``: ``(% female, % male, nan, nan, nan)``, where gender 0 is
        counted as female and 1 as male. Percentages are rounded to one decimal,
        and both are 0.0 when no rows match.

    Raises
    ------
    AssertionError
        If ``y_actual`` of the matched episodes disagrees with ``reference_df.label``,
        i.e. the two tables are not aligned by index.
    """
    if reference_df is None:
        reference_df = testing_df  # notebook-global raw test set
    # regex=False: link strings are literal text, so feature/class names must not
    # be interpreted as regular-expression syntax (e.g. parentheses).
    matches = df['trajectory'].str.contains(row['link_string'], case=False, regex=False)
    path_test_df = df[matches]
    path_reference_df = reference_df.loc[path_test_df.index]
    assert (path_test_df['y_actual'] == path_reference_df['label']).all(), (
        'episode table and reference table are not aligned by index')

    parent = row['Parent']
    if parent != 'gender':
        values = path_reference_df[parent]
        return values.mean(), values.std(), values.min(), values.max(), values.median()

    # gender is encoded 0 = female, 1 = male. value_counts handles any number of
    # rows (including one or zero), so no special single-row branch is needed.
    shares = path_reference_df['gender'].value_counts(normalize=True).mul(100).round(1)
    females = float(shares.get(0, 0.0))
    males = float(shares.get(1, 0.0))
    return females, males, np.nan, np.nan, np.nan

In [31]:
def create_sankey_df(df):
    """Build the link table used to draw a Sankey diagram of agent trajectories.

    Parameters
    ----------
    df : pandas.DataFrame
        Episode table (a subset of ``test_df``) with ``trajectory`` and
        ``y_actual`` columns.

    Returns
    -------
    pandas.DataFrame
        One row per (parent, child) transition returned by
        ``utils.generate_tuple_dict``. Columns are ``Parent``, ``Child``,
        ``value`` (the count), ``link_string`` and the five ``feat_*`` summaries
        from ``get_pathway_thresholds``.
        Two anchor rows follow: one with only ``Parent='hemoglobin'`` and one with
        only ``Child=<diagnosis node>``.
        "Anemia of chronic disease" / "Iron deficiency anemia" are abbreviated to
        ACD / IDA in every row.
        An empty DataFrame is returned when ``df`` has no rows.

    Raises
    ------
    ValueError
        If no child node looks like a diagnosis (contains "anemia" or "inconclusive").
    """
    if len(df) == 0:
        return pd.DataFrame()

    link_counts = utils.generate_tuple_dict(df)
    links = list(link_counts.keys())
    sankey_df = pd.DataFrame({
        'Parent': [link[0] for link in links],
        'Child': [link[1] for link in links],
        'value': list(link_counts.values()),
        'link_string': [tup_to_string(link) for link in links],
    })
    # For 'gender' links the feat_mean / feat_std columns hold % female / % male.
    sankey_df[['feat_mean', 'feat_std', 'feat_min', 'feat_max', 'feat_median']] = sankey_df.apply(
        get_pathway_thresholds, args=(df,), axis=1, result_type='expand')

    # NOTE(review): assumes every episode in df ends in the same diagnosis
    # (callers filter by y_pred), so the first diagnosis-like child is "the" class.
    diagnosis_nodes = [child for child in sankey_df['Child']
                       if 'anemia' in child.lower() or 'inconclusive' in child.lower()]
    if not diagnosis_nodes:
        raise ValueError('no diagnosis node (anemia / inconclusive) among Sankey children')
    anemia_class = diagnosis_nodes[0]

    # Anchor rows with only a Parent / only a Child. Presumably these pin
    # hemoglobin as the source and the diagnosis as the sink in the plot — TODO confirm.
    # They are added once, before abbreviation, so their labels match the other rows.
    anchors = pd.DataFrame([{'Parent': 'hemoglobin'}, {'Child': anemia_class}])
    sankey_df = pd.concat([sankey_df, anchors], ignore_index=True)

    abbreviations = {'Anemia of chronic disease': 'ACD', 'Iron deficiency anemia': 'IDA'}
    sankey_df['Parent'] = sankey_df['Parent'].replace(abbreviations)
    sankey_df['Child'] = sankey_df['Child'].replace(abbreviations)
    return sankey_df

In [32]:
# Per-class precision / recall / F1, hard-coded.
# NOTE(review): presumably transcribed from a classification report of this model —
# confirm, and consider computing these from test_df instead.
results_dicts = [
    {'label': label, 'name': name, 'precision': precision, 'recall': recall, 'f1': f1}
    for label, name, precision, recall, f1 in (
        (0, 'No anemia', 0.92, 1.00, 0.96),
        (1, 'Vitamin B12/Folate deficiency anemia', 0.96, 0.98, 0.97),
        (2, 'Unspecified anemia', 1.00, 0.98, 0.99),
        (3, 'Anemia of chronic disease', 0.99, 0.97, 0.98),
        (4, 'Iron deficiency anemia', 0.98, 0.98, 0.98),
        (5, 'Hemolytic anemia', 1.00, 0.94, 0.97),
        (6, 'Aplastic anemia', 1.00, 0.94, 0.97),
        (7, 'Inconclusive diagnosis', 0.94, 0.97, 0.95),
    )
]

#### Scratch check (delete from here)

In [27]:
# Scratch check: raw (parent, child) transition counts for the episodes predicted
# as "No anemia". The subset is built here instead of reusing `anem_test_df`.
# That variable is only defined in a later cell, so reusing it made this cell fail
# on a fresh kernel.
overall_tup_dict = utils.generate_tuple_dict(
    test_df[test_df.y_pred == results_dicts[0]['label']]
)
overall_tup_dict

{('hemoglobin', 'No anemia'): 337,
 ('hemoglobin', 'gender'): 1216,
 ('gender', 'No anemia'): 1827,
 ('hemoglobin', 'hematocrit'): 611,
 ('hematocrit', 'gender'): 611}

#### End of scratch check (delete up to here)

#### No anemia

In [33]:
# Sankey link table for the episodes the agent predicted as "No anemia" (label 0).
anem_dict = results_dicts[0]
anem_test_df = test_df.loc[test_df['y_pred'].eq(anem_dict['label'])]
anem_sankey_df = create_sankey_df(anem_test_df)
anem_sankey_df

Unnamed: 0,Parent,Child,value,link_string,feat_mean,feat_std,feat_min,feat_max,feat_median
0,hemoglobin,No anemia,337.0,"hemoglobin', 'No anemia",13.599523,0.261987,13.13246,14.040715,13.604035
1,hemoglobin,gender,1216.0,"hemoglobin', 'gender",13.693527,1.239616,11.870551,15.600242,14.080769
2,gender,No anemia,1827.0,"gender', 'No anemia",61.7,38.3,,,
3,hemoglobin,hematocrit,611.0,"hemoglobin', 'hematocrit",16.433996,0.461841,15.601169,17.1977,16.46906
4,hematocrit,gender,611.0,"hematocrit', 'gender",49.301989,1.385524,46.803506,51.5931,49.407179
5,hemoglobin,,,,,,,,
