In [2]:
import pandas as pd
import numpy as np
import seaborn as sns
from datetime import datetime
import random
import os
import torch
import sys
sys.path.append('../..')
from modules.many_features import utils, constants
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Reproducibility setup: one seed shared by Python's `random`, NumPy and PyTorch.
SEED = 42

# NOTE: the hash seed of the running interpreter is fixed at start-up, so this
# assignment is only picked up by processes launched from this kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Require deterministic PyTorch algorithms (ops that have none raise an error).
torch.use_deterministic_algorithms(True)

#### The Data

In [4]:
# Active dataset for this run.
df = pd.read_csv('../../data/more_features/more_feats_new_labels_0.1_noisy_0.6.csv')

# Other datasets that were tried (swap the path above to use one of them):
#   ../../data/anemia_synth_dataset_some_nans_unspecified_more_feats.csv
#   ../../data/more_feats_0.2.csv
#   ../../data/more_features/more_feats_new_labels_0.1.csv
#   ../../data/more_features/more_feats_0.3.csv
#
# Optional preprocessing steps, currently disabled:
#   df = utils.balance_dataset(df, 8000)
#   df = df.fillna(-1)
df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,age,gender,indirect_bilirubin,transferrin,creatinine,cholestrol,copper,ethanol,folate,glucose,label
0,12.19083,163.097819,3.084127,5.923027,232.339305,79.55047,-1.0,3.748689,-1.0,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,Anemia of chronic disease
1,9.944819,187.086684,5.096256,-1.0,-1.0,95.125054,-1.0,5.857746,72.837294,0,1.146823,-1.0,-1.0,-1.0,-1.0,20.943783,-1.0,-1.0,Unspecified anemia
2,13.703345,-1.0,-1.0,-1.0,-1.0,-1.0,80.372015,5.684361,45.702318,0,1.789854,206.817706,1.551467,90.217875,76.009442,22.651663,12.551513,-1.0,No anemia
3,7.346123,27.669632,-1.0,-1.0,428.089083,77.375356,95.912445,6.469686,52.594562,0,0.22208,297.319109,1.265341,126.999491,46.998397,77.183582,8.518821,119.113878,Anemia of chronic disease
4,12.295548,225.097199,3.827717,0.0,-1.0,102.137453,-1.0,-1.0,-1.0,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,Unspecified anemia


In [5]:
# Sanity check: number of NaN entries per column.
df.isna().sum()

hemoglobin               0
ferritin                 0
ret_count                0
segmented_neutrophils    0
tibc                     0
mcv                      0
serum_iron               0
rbc                      0
age                      0
gender                   0
indirect_bilirubin       0
transferrin              0
creatinine               0
cholestrol               0
copper                   0
ethanol                  0
folate                   0
glucose                  0
label                    0
dtype: int64

In [6]:
# Baseline metrics from the project helper (presumably a decision tree, going by the name).
# NOTE(review): the output looks like (accuracy, f1, roc_auc, elapsed time) — confirm in utils.
utils.get_dt_performance(df)

(0.9402857142857143,
 0.9273100518851077,
 0.9593184226883218,
 datetime.timedelta(microseconds=3001))

In [7]:
# Class balance: number of rows per diagnosis label.
df['label'].value_counts()

No anemia                               10000
Anemia of chronic disease                9756
Iron deficiency anemia                   9267
Unspecified anemia                       9033
Aplastic anemia                          9020
Vitamin B12/Folate deficiency anemia     9000
Hemolytic anemia                         8976
Inconclusive diagnosis                   4948
Name: label, dtype: int64

In [8]:
# Map each diagnosis string in 'label' through constants.CLASS_DICT (updates df in place).
class_dict = constants.CLASS_DICT
df['label'] = df['label'].replace(class_dict)

# Every column except the last is a feature; the last column is the target.
X, y = df.iloc[:, :-1], df.iloc[:, -1]

# Stratified 80/20 split, converted to plain NumPy arrays.
X_train, X_test, y_train, y_test = map(
    np.array,
    train_test_split(X, y, test_size=0.2, stratify=y, random_state=SEED),
)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((56000, 18), (14000, 18), (56000,), (14000,))

In [9]:
# Peek at the first test observation (one value per feature column).
# NOTE(review): -1.0 entries presumably mark features that were not measured — confirm.
X_test[0]

array([ 10.34995981,  -1.        ,   0.99855381,  -1.        ,
        -1.        ,  86.411147  , 111.70748795,   4.94433984,
        27.50298959,   0.        ,   0.27929003, 157.7912435 ,
         0.72896824,  19.37271239,  51.43928931,  43.59894376,
         7.55852452,  64.72339742])

In [10]:
# Action space: the class names from class_dict first, then one entry per
# non-label column of df.
action_list = [
    *class_dict.keys(),
    *(name for name in df.columns if name != 'label'),
]
action_list

['No anemia',
 'Vitamin B12/Folate deficiency anemia',
 'Unspecified anemia',
 'Anemia of chronic disease',
 'Iron deficiency anemia',
 'Hemolytic anemia',
 'Aplastic anemia',
 'Inconclusive diagnosis',
 'hemoglobin',
 'ferritin',
 'ret_count',
 'segmented_neutrophils',
 'tibc',
 'mcv',
 'serum_iron',
 'rbc',
 'age',
 'gender',
 'indirect_bilirubin',
 'transferrin',
 'creatinine',
 'cholestrol',
 'copper',
 'ethanol',
 'folate',
 'glucose']

In [11]:
# Total number of actions (class names + feature columns).
len(action_list)

26

In [12]:
# Re-display the frame: 'label' now holds the values mapped through constants.CLASS_DICT.
df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,age,gender,indirect_bilirubin,transferrin,creatinine,cholestrol,copper,ethanol,folate,glucose,label
0,12.19083,163.097819,3.084127,5.923027,232.339305,79.55047,-1.0,3.748689,-1.0,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,3
1,9.944819,187.086684,5.096256,-1.0,-1.0,95.125054,-1.0,5.857746,72.837294,0,1.146823,-1.0,-1.0,-1.0,-1.0,20.943783,-1.0,-1.0,2
2,13.703345,-1.0,-1.0,-1.0,-1.0,-1.0,80.372015,5.684361,45.702318,0,1.789854,206.817706,1.551467,90.217875,76.009442,22.651663,12.551513,-1.0,0
3,7.346123,27.669632,-1.0,-1.0,428.089083,77.375356,95.912445,6.469686,52.594562,0,0.22208,297.319109,1.265341,126.999491,46.998397,77.183582,8.518821,119.113878,3
4,12.295548,225.097199,3.827717,0.0,-1.0,102.137453,-1.0,-1.0,-1.0,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2


#### Training 

In [13]:
# %%time
# timesteps = int(2e6)
# dqn_model = utils.stable_dqn3(X_train, y_train, timesteps, True, f'../../models/many_features/stable_dqn3_{timesteps}')
# test_df = utils.evaluate_dqn(dqn_model, X_test, y_test)
# test_df.head()
# (Disabled cell: a single training run at 2e6 timesteps followed by evaluation on the test split.)

In [15]:
# Train one model per timestep budget. The budget is embedded in the path handed
# to utils.stable_dqn3 (presumably the save location — confirm in utils).
# Budgets from an earlier sweep, kept for reference:
#   6e6, 6.5e6, 7e6, 7.5e6, 8e6, 8.5e6, 9e6
# Wall-clock timing (datetime.now() before/after the call, then a print of the
# duration per budget) used to wrap the call and is currently disabled.
for steps in (16_000_000, 18_000_000):
    dqn_model = utils.stable_dqn3(
        X_train,
        y_train,
        steps,
        True,
        f'../../models/many_features/0.1/dqn3_by_type_new_labels_noisy_6_{steps}',
    )

using stable baselines 3
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.87     |
|    ep_rew_mean      | -0.84    |
|    exploration_rate | 0.827    |
|    success_rate     | 0.09     |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 611      |
|    time_elapsed     | 476      |
|    total_timesteps  | 291838   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 4.11     |
|    n_updates        | 60459    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.51     |
|    ep_rew_mean      | -0.78    |
|    exploration_rate | 0.633    |
|    success_rate     | 0.12     |
| time/               |          |
|    episodes         | 200000   |
|    fps              | 511      |
|    t

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.09     |
|    ep_rew_mean      | -0.06    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.48     |
| time/               |          |
|    episodes         | 1600000  |
|    fps              | 503      |
|    time_elapsed     | 11987    |
|    total_timesteps  | 6040719  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.184    |
|    n_updates        | 1497679  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.68     |
|    ep_rew_mean      | 0.12     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.59     |
| time/               |          |
|    episodes         | 1700000  |
|    fps              | 514      |
|    time_elapsed     | 12519    |
|    total_timesteps  | 6437765  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.03     |
|    ep_rew_mean      | 0.02     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.52     |
| time/               |          |
|    episodes         | 3100000  |
|    fps              | 610      |
|    time_elapsed     | 19596    |
|    total_timesteps  | 11964175 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.118    |
|    n_updates        | 2978543  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.9      |
|    ep_rew_mean      | 0.06     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.55     |
| time/               |          |
|    episodes         | 3200000  |
|    fps              | 616      |
|    time_elapsed     | 20038    |
|    total_timesteps  | 12359179 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.65     |
|    ep_rew_mean      | -0.98    |
|    exploration_rate | 0.119    |
|    success_rate     | 0.03     |
| time/               |          |
|    episodes         | 500000   |
|    fps              | 614      |
|    time_elapsed     | 2713     |
|    total_timesteps  | 1668420  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 10.3     |
|    n_updates        | 404604   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.82     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.05     |
|    success_rate     | 0.11     |
| time/               |          |
|    episodes         | 600000   |
|    fps              | 664      |
|    time_elapsed     | 3130     |
|    total_timesteps  | 2080506  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.15     |
|    ep_rew_mean      | -0.08    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.47     |
| time/               |          |
|    episodes         | 2000000  |
|    fps              | 565      |
|    time_elapsed     | 13348    |
|    total_timesteps  | 7549192  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.163    |
|    n_updates        | 1874797  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.01     |
|    ep_rew_mean      | -0.02    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.51     |
| time/               |          |
|    episodes         | 2100000  |
|    fps              | 571      |
|    time_elapsed     | 13891    |
|    total_timesteps  | 7944389  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.95     |
|    ep_rew_mean      | 0.06     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.6      |
| time/               |          |
|    episodes         | 3500000  |
|    fps              | 530      |
|    time_elapsed     | 25706    |
|    total_timesteps  | 13645498 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0439   |
|    n_updates        | 3398874  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.19     |
|    ep_rew_mean      | 0.14     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.62     |
| time/               |          |
|    episodes         | 3600000  |
|    fps              | 533      |
|    time_elapsed     | 26352    |
|    total_timesteps  | 14065406 |
| train/              |          |
|    learning_rate  

#### Testing

In [26]:
# Load a saved checkpoint and evaluate it on the held-out split.
# NOTE(review): this loads the 12000000-step model, which the training loop in this
# notebook (16e6 / 18e6 steps) does not produce — confirm the checkpoint exists.
training_env = utils.create_env(X_train, y_train)
dqn_model = utils.load_dqn3(
    '../../models/many_features/0.1/dqn3_by_type_new_labels_noisy_6_12000000',
    training_env,
)
test_df = utils.evaluate_dqn(dqn_model, X_test, y_test)
test_df.head()

Using stable baselines 3
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Count: 2800
Count: 5600
Count: 8400
Count: 11200
Count: 14000
Testing done.....


Unnamed: 0,episode_length,index,is_success,reward,terminated,trajectory,y_actual,y_pred
0,4.0,0.0,1.0,1.0,0.0,"[hemoglobin, mcv, ret_count, Aplastic anemia]",6.0,6.0
1,3.0,1.0,1.0,1.0,0.0,"[hemoglobin, mcv, Unspecified anemia]",2.0,2.0
2,5.0,2.0,1.0,1.0,0.0,"[hemoglobin, mcv, segmented_neutrophils, gende...",2.0,2.0
3,5.0,3.0,0.0,-1.0,0.0,"[hemoglobin, mcv, tibc, gender, Anemia of chro...",5.0,3.0
4,4.0,4.0,1.0,1.0,0.0,"[hemoglobin, mcv, ret_count, Hemolytic anemia]",5.0,5.0


In [27]:
# Distribution of the predicted class ids across the test episodes.
test_df['y_pred'].value_counts()

7.0    3059
3.0    2288
0.0    1998
1.0    1787
2.0    1584
4.0    1505
6.0    1016
5.0     763
Name: y_pred, dtype: int64

In [28]:
# NOTE(review): success_rate appears to be a percentage (72.03 here vs. accuracy 0.7203
# further down); success_df presumably holds the successful episodes — confirm in utils.
success_rate, success_df = utils.success_rate(test_df)
success_rate

72.02857142857142

In [29]:
# Average episode length and average reward over the evaluated episodes
# (as named by the helper; exact definition lives in utils).
avg_length, avg_return = utils.get_avg_length_reward(test_df)
avg_length, avg_return

(4.072071428571428, 0.31342857142857145)

In [30]:
# Compare true vs. predicted class ids; the returned triple is unpacked as
# accuracy, F1 and ROC-AUC.
acc, f1, roc_auc = utils.test(test_df['y_actual'], test_df['y_pred'])
acc, f1, roc_auc

(0.7202857142857143, 0.7108547931825261, 0.843465232085614)

In [31]:
# Which class ids were predicted at least once.
test_df['y_pred'].unique()

array([6., 2., 3., 5., 0., 4., 1., 7.])

In [32]:
# Inspect the episodes whose predicted class id is 4.
test_df.loc[test_df['y_pred'] == 4]

Unnamed: 0,episode_length,index,is_success,reward,terminated,trajectory,y_actual,y_pred
8,4.0,8.0,1.0,1.0,0.0,"[hemoglobin, mcv, tibc, Iron deficiency anemia]",4.0,4.0
16,4.0,16.0,1.0,1.0,0.0,"[hemoglobin, mcv, tibc, Iron deficiency anemia]",4.0,4.0
23,5.0,23.0,1.0,1.0,0.0,"[hemoglobin, mcv, tibc, ferritin, Iron deficie...",4.0,4.0
36,4.0,36.0,0.0,-1.0,0.0,"[hemoglobin, mcv, tibc, Iron deficiency anemia]",5.0,4.0
44,5.0,44.0,1.0,1.0,0.0,"[hemoglobin, gender, mcv, tibc, Iron deficienc...",4.0,4.0
...,...,...,...,...,...,...,...,...
13948,4.0,13948.0,1.0,1.0,0.0,"[hemoglobin, mcv, tibc, Iron deficiency anemia]",4.0,4.0
13968,6.0,13968.0,1.0,1.0,0.0,"[hemoglobin, gender, mcv, ferritin, tibc, Iron...",4.0,4.0
13989,6.0,13989.0,1.0,1.0,0.0,"[hemoglobin, mcv, tibc, ferritin, gender, Iron...",4.0,4.0
13996,6.0,13996.0,1.0,1.0,0.0,"[hemoglobin, mcv, ferritin, tibc, gender, Iron...",4.0,4.0


#### Saving files

In [33]:
# Persist the evaluation frames; the file names carry the same run tag
# (noisy_6_12000000) as the checkpoint loaded in the Testing section.
test_df.to_csv('../../test_dfs/many_features/0.1/test_df3_noisy_6_12000000.csv', index=False)
success_df.to_csv('../../test_dfs/many_features/0.1/success_df3_noisy_6_12000000.csv', index=False)

#### Confusion matrix and classification report

In [None]:
# Reload a saved evaluation frame for the report and confusion matrix below.
# NOTE(review): this reads test_df3_6500000.csv, not the test_df3_noisy_6_12000000.csv
# written in the "Saving files" section, and it overwrites `test_df` — confirm which
# run the following plots are meant to describe.
test_df = pd.read_csv('../../test_dfs/many_features/0.1/test_df3_6500000.csv')
test_df.head()

In [None]:
# Classification report for the currently loaded test_df (true vs. predicted class ids).
utils.plot_classification_report(test_df['y_actual'], test_df['y_pred'])

In [None]:
# Disabled local draft of a confusion-matrix plot; the notebook calls
# utils.plot_confusion_matrix instead.
# NOTE(review): the draft hard-codes 7 class labels (0-6) while the evaluation in this
# notebook predicts 8 classes (0-7) — adjust before re-enabling.
# def plot_confusion_matrix(y_actual, y_pred, save=False, filename=False):
#     from sklearn.metrics import confusion_matrix
#     cm = confusion_matrix(y_actual, y_pred)
#     cm_df = pd.DataFrame(cm, index = [0, 1, 2, 3, 4, 5, 6], columns = [0, 1, 2, 3, 4, 5, 6], dtype='object')
#     #cm_df = pd.DataFrame(cm, index = constants.CLASS_DICT.keys(), columns = constants.CLASS_DICT.keys())
#     plt.figure(figsize=(8, 6))
#     sns.heatmap(cm_df, annot=True)
#     plt.title('Confusion Matrix')
#     plt.ylabel('Actual Anemia')
#     plt.xlabel('Predicted Anemia')
#     plt.tight_layout()
#     if save:
#         plt.savefig(filename)
#     plt.show()
#     plt.close()

In [None]:
# Confusion matrix of true vs. predicted class ids for the currently loaded test_df.
utils.plot_confusion_matrix(test_df['y_actual'], test_df['y_pred'])