In [1]:
import pandas as pd
import numpy as np
import random
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
import sys
sys.path.append('../')
from modules import utils, constants
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix, classification_report 
from sklearn.metrics import roc_curve, auc

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



  "stable-baselines is in maintenance mode, please use [Stable-Baselines3 (SB3)](https://github.com/DLR-RM/stable-baselines3) for an up-to-date version. You can find a [migration guide](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html) in SB3 documentation."


In [2]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

#### The datasets

In [3]:
train_df = pd.read_csv('../data/train_set_basic.csv')
train_df = train_df.fillna(-1)
X_train = train_df.iloc[:, 0:-1]
y_train = train_df.iloc[:, -1]
X_train.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat
0,14.728733,-1.0,3.170892,-1.0,-1.0,-1.0,-1.0,-1.0,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,44.1862,-1.0
1,10.405752,9.634615,5.659537,-1.0,-1.0,77.413788,212.671838,4.032519,0,0.88713,96.311597,-1.0,43.218595,-1.0,83.207518,31.217256,-1.0
2,15.132737,358.914888,1.842252,3.797487,315.102272,80.500314,-1.0,5.639507,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,45.398211,-1.0
3,11.340169,-1.0,1.662209,2.441767,-1.0,97.033963,102.079062,3.506041,1,1.020527,127.281715,-1.0,20.847013,-1.0,62.210273,34.020508,-1.0
4,6.691485,-1.0,3.337971,-1.0,-1.0,99.838438,24.119564,2.010694,0,1.957666,34.633063,-1.0,34.612121,-1.0,112.411298,20.074456,-1.0


In [4]:
test_df = pd.read_csv('../data/test_set_constant.csv')
X_test = test_df.iloc[:, 0:-1]
y_test = test_df.iloc[:, -1]
X_test.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat
0,7.116363,-1.0,3.781573,2.738413,-1.0,95.904198,68.457895,2.226085,0,1.892912,39.80855,110.329197,64.40435,21.654404,73.787009,21.349089,-1.0
1,8.12532,92.230003,4.231419,1.188039,143.365567,104.057204,204.747831,2.342554,0,0.652614,13.478089,-1.0,32.705481,-1.0,43.520272,24.375961,142.815207
2,11.30945,38.324563,-1.0,-1.0,455.077909,76.402602,-1.0,4.440732,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,33.92835,-1.0
3,13.763858,253.513394,2.262606,0.551444,453.772884,82.781943,90.101466,4.987993,0,0.853521,104.005514,34.639227,0.963866,22.083012,88.891838,41.291574,19.856071
4,11.464002,-1.0,-1.0,-1.0,320.964653,104.287127,-1.0,3.297819,0,1.163516,121.616315,105.895897,-1.0,9.337462,-1.0,34.392007,-1.0


#### The model

In [5]:
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(random_state=SEED).fit(X_train, y_train)
y_pred_dt = dt.predict(X_test)
test_df_dt = pd.DataFrame()
test_df_dt['y_actual'] = y_test
test_df_dt['y_pred'] = y_pred_dt
test_df_dt.isna().sum()

y_actual    0
y_pred      0
dtype: int64

In [6]:
accuracy_score(test_df_dt.y_actual, test_df_dt.y_pred)*100

99.93571428571428

In [8]:
# from sklearn import tree
# text_representation = tree.export_text(dt, 
#                                        feature_names=train_df.columns[:-1].tolist(),  
#                                        )
# print(text_representation)

In [9]:
def generate_dt_trajectory(row):
    trajectory = []
    if row.mcv <= 100:
        trajectory.append('mcv')
        if row.hematocrit <=39:
            trajectory.append('hematocrit')
            if row.mcv <= 80:
                trajectory.append('mcv')
                if row.ferritin <=100:
                    trajectory.append('ferritin')
                    if row.ferritin < -0.49:
                        # trajectory.append('ferritin')
                        if row.rbc <=0.63:
                            trajectory.append('rbc')
                            trajectory.append('No anemia')
                            return trajectory
                        else: # if rbc > 0.63
                            trajectory.append('rbc')
                            if row.hemoglobin <=13:
                                trajectory.append('hemoglobin')
                                if row.rbc <= 4.87:
                                    trajectory.append('rbc')
                                    if row.hemoglobin <= 12.17:
                                        trajectory.append('hemoglobin')
                                        trajectory.append('Inconclusive diagnosis')
                                        return trajectory
                                    else: #hemoglobin > 12.17
                                        trajectory.append('hemoglobin')
                                        if row.gender <=0.50:
                                            trajectory.append('gender')
                                            trajectory.append('No anemia')
                                            return trajectory 
                                        else: #gender > 0.50
                                            trajectory.append('gender')
                                            trajectory.append('Inconclusive diagnosis')
                                            return trajectory 
                                else: # rbc > 4.87
                                    trajectory.append('rbc')
                                    if row.gender <=0.50:
                                        trajectory.append('gender')
                                        trajectory.append('No anemia')
                                        return trajectory 
                                    else: #gender > 0.50
                                        trajectory.append('gender')
                                        trajectory.append('Inconclusive diagnosis')
                                        return trajectory 
                            else: #hemoglobin > 13
                                trajectory.append('hemoglobin')
                                trajectory.append('No anemia')
                                return trajectory 
                    else: #ferritin > -0.49
                        # trajectory.append('ferritin')
                        if row.tibc <=449.75:
                            trajectory.append('tibc')
                            if row.ferritin <=30.06:
                                trajectory.append('ferritin')
                                if row.hematocrit<=38.71:
                                    trajectory.append('hematocrit')
                                    if row.hematocrit <=36.01:
                                        # trajectory.append('hematocrit')
                                        trajectory.append('Iron deficiency anemia')
                                        return trajectory 
                                    else: #hematocrit > 36.01
                                        # trajectory.append('hematocrit')
                                        if row.gender <=0.50:
                                            trajectory.append('gender')
                                            trajectory.append('No anemia')
                                            return trajectory 
                                        else: #gender > 0.50
                                            trajectory.append('gender')
                                            trajectory.append('Iron deficiency anemia')
                                            return trajectory 
                                else: #hematocrit > 38.71
                                    trajectory.append('hematocrit')
                                    if row.hematocrit <= 38.74:
                                        # trajectory.append('hematocrit')
                                        trajectory.append('No anemia')
                                        return trajectory 
                                    else: #hematocrit > 38.74
                                        # trajectory.append('hematocrit')
                                        trajectory.append('Iron deficiency anemia')
                                        return trajectory 
                            else: #ferritin > 30.06
                                trajectory.append('ferritin')
                                if row.tibc <= 49.53:
                                    trajectory.append('tibc')
                                    trajectory.append('Inconclusive diagnosis')
                                    return trajectory 
                                else: #if tibc > 49.53
                                    trajectory.append('tibc')
                                    if row.copper <=119.54:
                                        trajectory.append('copper')
                                        if row.rbc <= 4.96:
                                            trajectory.append('rbc')
                                            trajectory.append('Anemia of chronic disease')
                                            return trajectory 
                                        else: # if rbc > 4.96
                                            trajectory.append('rbc')
                                            if row.gender <=0.50:
                                                trajectory.append('gender')
                                                trajectory.append('No anemia')
                                                return trajectory 
                                            else: #if gender > 0.50
                                                trajectory.append('gender')
                                                trajectory.append('Anemia of chronic disease')
                                                return trajectory 
                                    else: #copper > 119.54
                                        trajectory.append('copper')
                                        if row.hemoglobin <=12.02:
                                            trajectory.append('hemoglobin')
                                            trajectory.append('Anemia of chronic disease')
                                            return trajectory 
                                        else: #hemoglobin > 12.02
                                            trajectory.append('hemoglobin')
                                            if row.gender <= 0.50:
                                                trajectory.append('gender')
                                                trajectory.append('No anemia')
                                                return trajectory
                                            else: #gender > 0.50
                                                trajectory.append('gender')
                                                trajectory.append('Anemia of chronic disease')
                                                return trajectory 
                        else: #tibc > 449.75
                            trajectory.append('tibc')
                            if row.rbc <= 5.14:
                                trajectory.append('rbc')
                                if row.rbc <= 4.79:
                                    # trajectory.append('rbc')
                                    trajectory.append('Iron deficiency anemia')
                                    return trajectory 
                                else: # rbc > 4.79
                                    # trajectory.append('rbc')
                                    if row.gender < 0.50:
                                        trajectory.append('gender')
                                        trajectory.append('No anemia')
                                        return trajectory 
                                    else: #gender > 0.50
                                        trajectory.append('gender')
                                        trajectory.append('Iron deficiency anemia')
                                        return trajectory 
                            else: #rbc > 5.14
                                trajectory.append('rbc')
                                if row.hematocrit <= 38.68:
                                    trajectory.append('hematocrit')
                                    trajectory.append('No anemia')
                                    return trajectory 
                                else: # hematocrit > 38.68
                                    trajectory.append('hematocrit')
                                    trajectory.append('Iron deficiency anemia')
                                    return trajectory 
                else: #ferritin > 100
                    trajectory.append('ferritin')
                    if row.hematocrit <= 36.00:
                        trajectory.append('hematocrit')
                        trajectory.append('Anemia of chronic disease')
                        return trajectory 
                    else: #hematocrit > 36.00
                        trajectory.append('hematocrit')
                        if row.gender <=0.50:
                            trajectory.append('gender')
                            trajectory.append('No anemia')
                            return trajectory 
                        else: #gender > 0.50
                            trajectory.append('gender')
                            trajectory.append('Anemia of chronic disease')
                            return trajectory 
            else: #mcv > 80
                trajectory.append('mcv')
                if row.ret_count <= 2.10:
                    trajectory.append('ret_count')
                    if row.ret_count <= -0.50:
                        trajectory.append('ret_count')
                        if row.hemoglobin <= 12.00:
                            trajectory.append('hemoglobin')
                            trajectory.append('Inconclusive diagnosis')
                            return trajectory 
                        else: # if hemoglobin > 12
                            trajectory.append('hemoglobin')
                            if row.gender <=0.50:
                                trajectory.append('gender')
                                trajectory.append('No anemia')
                                return trajectory 
                            else:
                                trajectory.append('gender')
                                trajectory.append('Inconclusive diagnosis')
                                return trajectory 
                    else: # ret_count > 0.50
                        trajectory.append('ret_count')
                        if row.hematocrit <= 36.52:
                            trajectory.append('hematocrit')
                            if row.ret_count < 2.05:
                                trajectory.append('ret_count')
                                if row.hemoglobin <= 12.03:
                                    trajectory.append('hemoglobin')
                                    trajectory.append('Aplastic anemia')
                                    return trajectory 
                                else: #hemoglobin > 12.03
                                    trajectory.append('hemoglobin')
                                    if row.gender <= 0.50:
                                        trajectory.append('gender')
                                        trajectory.append('No anemia')
                                        return trajectory 
                                    else: #gender > 0.50
                                        trajectory.append('gender')
                                        trajectory.append('Aplastic anemia')
                                        return trajectory 
                            else: #ret_count > 2.05
                                trajectory.append('ret_count')
                                trajectory.append('No anemia')
                                return trajectory 
                        else: #hematocrit > 36.52
                            trajectory.append('hematocrit')
                            if row.gender <=0.50:
                                trajectory.append('gender')
                                trajectory.append('No anemia')
                                return trajectory 
                            else:
                                trajectory.append('gender')
                                trajectory.append('Aplastic anemia')
                                return trajectory 
                else: #ret_count > 2.10
                    trajectory.append('ret_count')
                    if row.hemoglobin <= 12:
                        trajectory.append('hemoglobin')
                        if row.mcv <=99.99:
                            trajectory.append('mcv')
                            trajectory.append('Hemolytic anemia')
                            return trajectory 
                        else: #mcv > 99.99
                            trajectory.append('mcv')
                            trajectory.append('Inconclusive diagnosis')
                            return trajectory 
                    else: #hemoglobin >12
                        trajectory.append('hemoglobin')
                        if row.gender <=0.50:
                            trajectory.append('gender')
                            trajectory.append('No anemia')
                            return trajectory 
                        else: #gender >0.50
                            trajectory.append('gender')
                            trajectory.append('Hemolytic anemia')
                            return trajectory 
        else: #hematocrit > 39
            trajectory.append('hematocrit')
            if row.hematocrit <= 39:
                # trajectory.append('hematocrit')
                if row.glucose <= 49.62:
                    trajectory.append('glucose')
                    trajectory.append('Inconclusive diagnosis')
                    return trajectory 
                else: #glucose > 49.62
                    trajectory.append('glucose')
                    trajectory.append('No anemia')
                    return trajectory 
            else: #hematocrit > 39
                # trajectory.append('hematocrit')
                trajectory.append('No anemia')
                return trajectory 
    else: #mcv >100
        trajectory.append('mcv')
        if row.segmented_neutrophils <= 0.05:
            trajectory.append('segmented_neutrophils')
            if row.segmented_neutrophils <= -0.5:
                # trajectory.append('segmented_neutrophils')
                if row.hemoglobin <= 12.99:
                    trajectory.append('hemoglobin')
                    if row.hematocrit <= 36.01:
                        trajectory.append('hematocrit')
                        trajectory.append('Inconclusive diagnosis')
                        return trajectory 
                    else: #hematocrit > 36.01
                        trajectory.append('hematocrit')
                        if row.gender <= 0.50:
                            trajectory.append('gender')
                            trajectory.append('No anemia')
                            return trajectory 
                        else: #gender > 0.50
                            trajectory.append('gender')
                            trajectory.append('Inconclusive diagnosis')
                            return trajectory 
                else: #hemoglobin > 12.99
                    trajectory.append('hemoglobin')
                    trajectory.append('No anemia')
                    return trajectory 
            else: #segmented_neutrophils > -0.50
                # trajectory.append('segmented_neutrophils')
                trajectory.append('Unspecified anemia')
                return trajectory 
        else: #segmented_neutrophils > 0.05
            trajectory.append('segmented_neutrophils')
            if row.hemoglobin <= 13.02:
                trajectory.append('hemoglobin')
                if row.hematocrit <= 36.12:
                    trajectory.append('hematocrit')
                    trajectory.append('Vitamin B12/Folate deficiency anemia')
                    return trajectory
                else: #hematocrit > 36.12
                    trajectory.append('hematocrit')
                    if row.gender <=0.50:
                        trajectory.append('gender')
                        trajectory.append('No anemia')
                        return trajectory 
                    else: #gender > 0.50
                        trajectory.append('gender')
                        trajectory.append('Vitamin B12/Folate deficiency anemia')
                        return trajectory 
            else: #hemoglobin > 13.02
                trajectory.append('hemoglobin')
                trajectory.append('No anemia')
                return trajectory                                  

In [10]:
test_df_dt['trajectory'] = X_test.apply(lambda row: generate_dt_trajectory(row), axis=1)
test_df_dt['refined_trajectory'] = test_df_dt['trajectory'].apply(lambda row: [*set(row)])
test_df_dt.head()

Unnamed: 0,y_actual,y_pred,trajectory,refined_trajectory
0,5,5,"[mcv, hematocrit, mcv, ret_count, hemoglobin, ...","[mcv, Hemolytic anemia, hematocrit, ret_count,..."
1,1,1,"[mcv, segmented_neutrophils, hemoglobin, hemat...","[mcv, hematocrit, segmented_neutrophils, Vitam..."
2,4,4,"[mcv, hematocrit, mcv, ferritin, tibc, rbc, Ir...","[mcv, ferritin, rbc, hematocrit, tibc, Iron de..."
3,0,0,"[mcv, hematocrit, No anemia]","[hematocrit, mcv, No anemia]"
4,7,7,"[mcv, segmented_neutrophils, hemoglobin, hemat...","[mcv, hematocrit, segmented_neutrophils, Incon..."


In [11]:
test_df_dt['episode_length'] = test_df_dt['trajectory'].apply(lambda row: len(row))
test_df_dt['refined_episode_length'] = test_df_dt['refined_trajectory'].apply(lambda row: len(row))
test_df_dt.head()

Unnamed: 0,y_actual,y_pred,trajectory,refined_trajectory,episode_length,refined_episode_length
0,5,5,"[mcv, hematocrit, mcv, ret_count, hemoglobin, ...","[mcv, Hemolytic anemia, hematocrit, ret_count,...",7,5
1,1,1,"[mcv, segmented_neutrophils, hemoglobin, hemat...","[mcv, hematocrit, segmented_neutrophils, Vitam...",5,5
2,4,4,"[mcv, hematocrit, mcv, ferritin, tibc, rbc, Ir...","[mcv, ferritin, rbc, hematocrit, tibc, Iron de...",7,6
3,0,0,"[mcv, hematocrit, No anemia]","[hematocrit, mcv, No anemia]",3,3
4,7,7,"[mcv, segmented_neutrophils, hemoglobin, hemat...","[mcv, hematocrit, segmented_neutrophils, Incon...",5,5


In [12]:
test_df_dt.episode_length.min(), test_df_dt.episode_length.mean(), test_df_dt.episode_length.max()

(3, 6.057214285714286, 11)

In [13]:
test_df_dt.refined_episode_length.min(), test_df_dt.refined_episode_length.mean(), test_df_dt.refined_episode_length.max()

(3, 4.6470714285714285, 8)

In [19]:
test_df_dt[test_df_dt.refined_episode_length == test_df_dt.refined_episode_length.max()].iloc[0]['refined_trajectory']

['mcv',
 'ferritin',
 'rbc',
 'hematocrit',
 'copper',
 'gender',
 'tibc',
 'Anemia of chronic disease']

In [17]:
a = ['mcv', 'hematocrit', 'mcv', 'ferritin', 'tibc', 'ferritin', 'tibc', 'copper', 'rbc', 'gender', 
     'Anemia of chronic disease']

In [18]:
[*set(a)]

['mcv',
 'ferritin',
 'rbc',
 'hematocrit',
 'copper',
 'gender',
 'tibc',
 'Anemia of chronic disease']