In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
from datetime import datetime
import random
import os
import torch
import sys
sys.path.append('../..')
from modules.many_features import utils, constants
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

  import pandas.util.testing as tm


In [2]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
os.environ['PYTHONHASHSEED']=str(SEED)
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(True)

#### The Data

In [3]:
#df = pd.read_csv('../../data/anemia_synth_dataset_some_nans_unspecified_more_feats.csv')
#df = pd.read_csv('../../data/more_feats_0.2.csv')
#df= pd.read_csv('../../data/more_features/more_feats_new_labels_0.1.csv')
#df =pd.read_csv('../../data/more_features/more_feats_new_labels_0.1_noisy_0.6.csv')
df = pd.read_csv('../../data/more_features//more_feats_correlated_noisy_6.csv')
#df = utils.balance_dataset(df, 8000)
df = df.fillna(-1)
df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,12.322384,163.121429,4.121959,-1.0,433.145097,100.147359,51.037057,3.691276,1,0.819482,147.693827,59.199141,41.958432,20.792161,101.383932,36.967153,11.782901,Unspecified anemia
1,8.298889,-1.0,2.07695,-1.0,483.617753,98.431076,-1.0,2.52935,0,-1.0,-1.0,-1.0,36.118322,-1.0,-1.0,24.896668,-1.0,Hemolytic anemia
2,12.696391,3.393723,-1.0,-1.0,451.933132,79.486542,85.001345,4.791902,1,-1.0,4.852168,89.831485,44.946238,0.965963,-1.0,38.089174,18.80839,Iron deficiency anemia
3,12.705102,-1.0,2.305379,-1.0,-1.0,81.057541,135.371313,4.702253,1,1.32414,32.717943,76.524319,-1.0,27.439316,-1.0,38.115305,-1.0,Aplastic anemia
4,8.211543,29.622561,-1.0,0.93619,479.914773,78.38848,-1.0,3.142634,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,24.634629,-1.0,Iron deficiency anemia


In [4]:
df.isna().sum()

hemoglobin               0
ferritin                 0
ret_count                0
segmented_neutrophils    0
tibc                     0
mcv                      0
serum_iron               0
rbc                      0
gender                   0
creatinine               0
cholestrol               0
copper                   0
ethanol                  0
folate                   0
glucose                  0
hematocrit               0
tsat                     0
label                    0
dtype: int64

In [5]:
utils.get_dt_performance(df)

(0.6596428571428572,
 0.6536656960257498,
 0.8005238687826254,
 datetime.timedelta(microseconds=2998))

In [6]:
df.label.value_counts()

No anemia                               16000
Anemia of chronic disease                8828
Iron deficiency anemia                   8331
Unspecified anemia                       8104
Aplastic anemia                          8093
Hemolytic anemia                         8089
Vitamin B12/Folate deficiency anemia     8082
Inconclusive diagnosis                   4473
Name: label, dtype: int64

In [7]:
class_dict = constants.CLASS_DICT
df['label'] = df['label'].replace(class_dict)
X = df.iloc[:, 0:-1]
y = df.iloc[:, -1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=SEED)
X_train, y_train = np.array(X_train), np.array(y_train)
X_test, y_test = np.array(X_test), np.array(y_test)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((56000, 17), (14000, 17), (56000,), (14000,))

In [8]:
action_list = list(class_dict.keys()) + [col  for col in df.columns if col!='label']
action_list

['No anemia',
 'Vitamin B12/Folate deficiency anemia',
 'Unspecified anemia',
 'Anemia of chronic disease',
 'Iron deficiency anemia',
 'Hemolytic anemia',
 'Aplastic anemia',
 'Inconclusive diagnosis',
 'hemoglobin',
 'ferritin',
 'ret_count',
 'segmented_neutrophils',
 'tibc',
 'mcv',
 'serum_iron',
 'rbc',
 'gender',
 'creatinine',
 'cholestrol',
 'copper',
 'ethanol',
 'folate',
 'glucose',
 'hematocrit',
 'tsat']

In [9]:
len(action_list)

25

In [10]:
df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,12.322384,163.121429,4.121959,-1.0,433.145097,100.147359,51.037057,3.691276,1,0.819482,147.693827,59.199141,41.958432,20.792161,101.383932,36.967153,11.782901,2
1,8.298889,-1.0,2.07695,-1.0,483.617753,98.431076,-1.0,2.52935,0,-1.0,-1.0,-1.0,36.118322,-1.0,-1.0,24.896668,-1.0,5
2,12.696391,3.393723,-1.0,-1.0,451.933132,79.486542,85.001345,4.791902,1,-1.0,4.852168,89.831485,44.946238,0.965963,-1.0,38.089174,18.80839,4
3,12.705102,-1.0,2.305379,-1.0,-1.0,81.057541,135.371313,4.702253,1,1.32414,32.717943,76.524319,-1.0,27.439316,-1.0,38.115305,-1.0,6
4,8.211543,29.622561,-1.0,0.93619,479.914773,78.38848,-1.0,3.142634,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,24.634629,-1.0,4


#### Training 

In [11]:
# %%time
# timesteps = int(2e6)
# dqn_model = utils.stable_dqn3(X_train, y_train, timesteps, True, f'../../models/many_features/stable_dqn3_{timesteps}')
# test_df = utils.evaluate_dqn(dqn_model, X_test, y_test)
# test_df.head()

In [None]:
#for steps in [int(6e6), int(6.5e6), int(7e6), int(7.5e6), int(8e6), int(8.5e6), int(9e6)]:
for steps in [int(9.5e6), int(10e6), int(10.5e6), int(11e6), int(11.5e6), int(12e6), int(12.5e6), int(13e6), int(13.5e6)]:
    #start_time = datetime.now()
    dqn_model = utils.stable_dqn3(X_train, y_train, steps, True, 
                                  f'../../models/many_features/0.1/with_correlated_fts/dqn3_by_type_noisy_6_{steps}')
    #end_time = datetime.now()
    #print(f'The duration for {steps} steps is {end_time-start_time}')

using stable baselines 3
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.88     |
|    ep_rew_mean      | -0.86    |
|    exploration_rate | 0.704    |
|    success_rate     | 0.07     |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 828      |
|    time_elapsed     | 356      |
|    total_timesteps  | 295532   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 12.4     |
|    n_updates        | 61382    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.39     |
|    ep_rew_mean      | -0.88    |
|    exploration_rate | 0.341    |
|    success_rate     | 0.1      |
| time/               |          |
|    episodes         | 200000   |
|    fps              | 702      |
|    t

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.52     |
|    ep_rew_mean      | -0.24    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.4      |
| time/               |          |
|    episodes         | 1600000  |
|    fps              | 486      |
|    time_elapsed     | 13667    |
|    total_timesteps  | 6650107  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.15     |
|    n_updates        | 1650026  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.36     |
|    ep_rew_mean      | -0.18    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.42     |
| time/               |          |
|    episodes         | 1700000  |
|    fps              | 487      |
|    time_elapsed     | 14541    |
|    total_timesteps  | 7082107  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.12     |
|    ep_rew_mean      | -0.32    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.35     |
| time/               |          |
|    episodes         | 900000   |
|    fps              | 443      |
|    time_elapsed     | 8074     |
|    total_timesteps  | 3578090  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.096    |
|    n_updates        | 882022   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.65     |
|    ep_rew_mean      | -0.18    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.44     |
| time/               |          |
|    episodes         | 1000000  |
|    fps              | 452      |
|    time_elapsed     | 8828     |
|    total_timesteps  | 3998841  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.25     |
|    ep_rew_mean      | -0.08    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.49     |
| time/               |          |
|    episodes         | 2400000  |
|    fps              | 552      |
|    time_elapsed     | 17616    |
|    total_timesteps  | 9737647  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.155    |
|    n_updates        | 2421911  |
----------------------------------
using stable baselines 3
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.47     |
|    ep_rew_mean      | -0.82    |
|    exploration_rate | 0.736    |
|    success_rate     | 0.1      |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 1044     |
|    t

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.63     |
|    ep_rew_mean      | 0        |
|    exploration_rate | 0.05     |
|    success_rate     | 0.5      |
| time/               |          |
|    episodes         | 1500000  |
|    fps              | 692      |
|    time_elapsed     | 8424     |
|    total_timesteps  | 5831187  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.058    |
|    n_updates        | 1445296  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.75     |
|    ep_rew_mean      | -0.28    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.37     |
| time/               |          |
|    episodes         | 1600000  |
|    fps              | 690      |
|    time_elapsed     | 8999     |
|    total_timesteps  | 6210417  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.42     |
|    ep_rew_mean      | -0.98    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.06     |
| time/               |          |
|    episodes         | 400000   |
|    fps              | 793      |
|    time_elapsed     | 1708     |
|    total_timesteps  | 1355939  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 4.92     |
|    n_updates        | 326484   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.54     |
|    ep_rew_mean      | -0.62    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 500000   |
|    fps              | 764      |
|    time_elapsed     | 2371     |
|    total_timesteps  | 1813786  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.8      |
|    ep_rew_mean      | -0.04    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.5      |
| time/               |          |
|    episodes         | 1900000  |
|    fps              | 650      |
|    time_elapsed     | 11142    |
|    total_timesteps  | 7254056  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.156    |
|    n_updates        | 1801013  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.67     |
|    ep_rew_mean      | 0.02     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.53     |
| time/               |          |
|    episodes         | 2000000  |
|    fps              | 633      |
|    time_elapsed     | 12040    |
|    total_timesteps  | 7630876  |
| train/              |          |
|    learning_rate  

#### Testing

In [14]:
# training_env = utils.create_env(X_train, y_train)
# dqn_model = utils.load_dqn3('../../models/many_features/0.1/with_correlated_fts/dqn3_by_type_noisy_4_13500000', training_env)
# test_df = utils.evaluate_dqn(dqn_model, X_test, y_test)
# test_df.head()

Using stable baselines 3
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Count: 2800
Count: 5600
Count: 8400
Count: 11200
Count: 14000
Testing done.....


Unnamed: 0,episode_length,index,is_success,reward,terminated,trajectory,y_actual,y_pred
0,2.0,0.0,1.0,1.0,0.0,"[hemoglobin, No anemia]",0.0,0.0
1,3.0,1.0,1.0,1.0,0.0,"[hemoglobin, gender, No anemia]",0.0,0.0
2,5.0,2.0,1.0,1.0,0.0,"[hemoglobin, rbc, mcv, tibc, Iron deficiency a...",4.0,4.0
3,5.0,3.0,1.0,1.0,0.0,"[hemoglobin, rbc, mcv, segmented_neutrophils, ...",1.0,1.0
4,5.0,4.0,0.0,-1.0,0.0,"[hemoglobin, rbc, mcv, tibc, Anemia of chronic...",6.0,3.0


In [15]:
success_rate, success_df = utils.success_rate(test_df)
success_rate

68.54285714285714

In [16]:
avg_length, avg_return = utils.get_avg_length_reward(test_df)
avg_length, avg_return

(4.8635, 0.25342857142857145)

In [17]:
acc, f1, roc_auc = utils.test(test_df['y_actual'], test_df['y_pred'])
acc, f1, roc_auc

(0.6854285714285714, 0.6763143063439356, 0.8310514175573392)

In [18]:
test_df.y_pred.unique()

array([0., 4., 1., 3., 7., 5., 6., 2.])

In [19]:
test_df[test_df.y_pred==4]

Unnamed: 0,episode_length,index,is_success,reward,terminated,trajectory,y_actual,y_pred
0,4.0,0.0,1.0,1.0,0.0,"[hemoglobin, rbc, ferritin, Iron deficiency an...",4.0,4.0
1,4.0,1.0,0.0,-1.0,0.0,"[hemoglobin, rbc, ferritin, Iron deficiency an...",3.0,4.0
8,5.0,8.0,1.0,1.0,0.0,"[hemoglobin, gender, rbc, ferritin, Iron defic...",4.0,4.0
11,4.0,11.0,1.0,1.0,0.0,"[hemoglobin, rbc, ferritin, Iron deficiency an...",4.0,4.0
36,4.0,36.0,0.0,-1.0,0.0,"[hemoglobin, rbc, ferritin, Iron deficiency an...",3.0,4.0
...,...,...,...,...,...,...,...,...
13966,4.0,13966.0,1.0,1.0,0.0,"[hemoglobin, rbc, ferritin, Iron deficiency an...",4.0,4.0
13970,4.0,13970.0,0.0,-1.0,0.0,"[hemoglobin, rbc, ferritin, Iron deficiency an...",7.0,4.0
13976,4.0,13976.0,1.0,1.0,0.0,"[hemoglobin, rbc, ferritin, Iron deficiency an...",4.0,4.0
13983,5.0,13983.0,1.0,1.0,0.0,"[hemoglobin, gender, rbc, ferritin, Iron defic...",4.0,4.0


#### Saving files

In [75]:
# test_df.to_csv(f'../../test_dfs/many_features/0.1/correlated/test_df3_9000000.csv', index=False)
# success_df.to_csv(f'../../test_dfs/many_features/0.1/correlated/success_df3_9000000.csv', index=False)

#### Confusion matrix and classification report

In [None]:
test_df = pd.read_csv('../../test_dfs/many_features/0.1/test_df3_6500000.csv')
test_df.head()

In [None]:
utils.plot_classification_report(test_df['y_actual'], test_df['y_pred'])

In [None]:
# def plot_confusion_matrix(y_actual, y_pred, save=False, filename=False):
#     from sklearn.metrics import confusion_matrix
#     cm = confusion_matrix(y_actual, y_pred)
#     cm_df = pd.DataFrame(cm, index = [0, 1, 2, 3, 4, 5, 6], columns = [0, 1, 2, 3, 4, 5, 6], dtype='object')
#     #cm_df = pd.DataFrame(cm, index = constants.CLASS_DICT.keys(), columns = constants.CLASS_DICT.keys())
#     plt.figure(figsize=(8, 6))
#     sns.heatmap(cm_df, annot=True)
#     plt.title('Confusion Matrix')
#     plt.ylabel('Actual Anemia')
#     plt.xlabel('Predicted Anemia')
#     plt.tight_layout()
#     if save:
#         plt.savefig(filename)
#     plt.show()
#     plt.close()

In [None]:
utils.plot_confusion_matrix(test_df['y_actual'], test_df['y_pred'])