In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import random
import os
from datetime import datetime
import torch
import sys
sys.path.append('../..')
from modules.many_features import utils, constants
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline

  import pandas.util.testing as tm


In [2]:
# Single source of truth for every seeded component in this notebook.
SEED = 42

# NOTE: PYTHONHASHSEED is read when an interpreter starts, so setting it here
# only reaches child processes spawned later, not the already-running kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)

# Seed each RNG used directly or indirectly below: stdlib, NumPy, PyTorch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Ask PyTorch to raise instead of silently using a non-deterministic kernel.
torch.use_deterministic_algorithms(True)

In [3]:
# Load the correlated, noisy feature set and replace every missing value with the
# sentinel -1 (presumably "measurement not available" — confirm with the data generator).
# An alternative dataset used previously: ../../data/more_features/with_correlated_feature_0.1.csv
df = pd.read_csv('../../data/more_features//more_feats_correlated_noisy_4.csv').fillna(-1)
df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,8.725667,-1.0,2.344526,-1.0,-1.0,96.678418,-1.0,2.707637,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,26.177001,-1.0,Hemolytic anemia
1,12.430232,231.691743,5.14366,5.339158,484.590922,87.353532,-1.0,4.26894,0,-1.0,-1.0,-1.0,8.718062,-1.0,-1.0,37.290695,-1.0,No anemia
2,6.73776,360.843434,-1.0,5.48114,421.177091,75.086568,82.986344,2.691997,0,0.458698,19.596767,69.527433,-1.0,23.027617,-1.0,20.213279,19.703433,Anemia of chronic disease
3,12.313034,-1.0,2.765581,0.0,-1.0,99.494524,-1.0,3.712677,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,36.939103,-1.0,Unspecified anemia
4,6.75814,-1.0,1.696525,5.648115,-1.0,86.643128,199.741346,2.339992,0,0.887158,80.614718,109.197814,18.770122,7.668335,63.159054,20.274421,-1.0,Aplastic anemia


In [4]:
# Baseline on the full frame, run before the labels are integer-encoded below.
# NOTE(review): the returned tuple looks like (accuracy, F1, ROC-AUC, fit duration),
# by analogy with `utils.test` further down — confirm against modules.many_features.utils.
utils.get_dt_performance(df)

(0.7016428571428571,
 0.7005761070620158,
 0.8266962007922731,
 datetime.timedelta(microseconds=2951))

In [5]:
# Class balance of the string diagnoses, most frequent first.
df['label'].value_counts()

No anemia                               16000
Anemia of chronic disease                8816
Iron deficiency anemia                   8309
Aplastic anemia                          8157
Unspecified anemia                       8132
Vitamin B12/Folate deficiency anemia     8124
Hemolytic anemia                         8057
Inconclusive diagnosis                   4405
Name: label, dtype: int64

In [6]:
class_dict = constants.CLASS_DICT

# Encode the string diagnoses as class ids *without* writing back into `df`:
# earlier cells that read `df.label` (baseline, value counts) then give the same
# result when re-run, and this cell itself is idempotent.
y = df['label'].map(class_dict)

# `.map` yields NaN for any label missing from CLASS_DICT; fail loudly here rather
# than letting an unencoded class slip into the split and the training run.
unmapped_labels = df.loc[y.isna(), 'label'].unique()
assert len(unmapped_labels) == 0, f'labels missing from CLASS_DICT: {unmapped_labels}'

# Select the features by name instead of relying on `label` being the last column.
X = df.drop(columns='label')

# Stratified 80/20 split, seeded so the partition is reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=SEED)

# Downstream helpers are handed plain NumPy arrays rather than pandas objects.
X_train, y_train = np.array(X_train), np.array(y_train)
X_test, y_test = np.array(X_test), np.array(y_test)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((56000, 17), (14000, 17), (56000,), (14000,))

In [7]:
# Every diagnosis name from `class_dict` first, followed by every non-label column of `df`.
action_list = [*class_dict.keys(), *(name for name in df.columns if name != 'label')]
action_list

['No anemia',
 'Vitamin B12/Folate deficiency anemia',
 'Unspecified anemia',
 'Anemia of chronic disease',
 'Iron deficiency anemia',
 'Hemolytic anemia',
 'Aplastic anemia',
 'Inconclusive diagnosis',
 'hemoglobin',
 'ferritin',
 'ret_count',
 'segmented_neutrophils',
 'tibc',
 'mcv',
 'serum_iron',
 'rbc',
 'gender',
 'creatinine',
 'cholestrol',
 'copper',
 'ethanol',
 'folate',
 'glucose',
 'hematocrit',
 'tsat']

In [None]:
# Train one DQN agent per step budget; the last argument is a per-budget path
# (presumably where the model is saved — confirm in utils.stable_dqn3).
# After the loop, `dqn_model` refers only to the final (largest-budget) run.
# For a quick smoke test, swap the budgets for a single tiny one such as 2_000.
for steps in (10_000_000, 12_000_000, 14_000_000, 15_000_000, 20_000_000):
    dqn_model = utils.stable_dqn3(
        X_train,
        y_train,
        steps,
        True,  # NOTE(review): positional flag whose meaning lives in utils.stable_dqn3 — confirm
        f'../../models/many_features/0.1/with_correlated_fts/dqn3_by_type_noisy_4_{steps}',
    )

using stable baselines 3
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.38     |
|    ep_rew_mean      | -0.82    |
|    exploration_rate | 0.722    |
|    success_rate     | 0.13     |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 802      |
|    time_elapsed     | 363      |
|    total_timesteps  | 292146   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 9.41     |
|    n_updates        | 60536    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.26     |
|    ep_rew_mean      | -0.92    |
|    exploration_rate | 0.381    |
|    success_rate     | 0.1      |
| time/               |          |
|    episodes         | 200000   |
|    fps              | 593      |
|    t

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.1      |
|    ep_rew_mean      | -0.28    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.36     |
| time/               |          |
|    episodes         | 1600000  |
|    fps              | 544      |
|    time_elapsed     | 12079    |
|    total_timesteps  | 6574835  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.114    |
|    n_updates        | 1631208  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.99     |
|    ep_rew_mean      | -0.36    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.36     |
| time/               |          |
|    episodes         | 1700000  |
|    fps              | 550      |
|    time_elapsed     | 12764    |
|    total_timesteps  | 7021512  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.22     |
|    ep_rew_mean      | -0.18    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.46     |
| time/               |          |
|    episodes         | 800000   |
|    fps              | 730      |
|    time_elapsed     | 4270     |
|    total_timesteps  | 3119870  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0275   |
|    n_updates        | 767467   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.5      |
|    ep_rew_mean      | -0.46    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.3      |
| time/               |          |
|    episodes         | 900000   |
|    fps              | 723      |
|    time_elapsed     | 4866     |
|    total_timesteps  | 3520009  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.69     |
|    ep_rew_mean      | -0.12    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.5      |
| time/               |          |
|    episodes         | 2300000  |
|    fps              | 621      |
|    time_elapsed     | 14184    |
|    total_timesteps  | 8817869  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0851   |
|    n_updates        | 2191967  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.47     |
|    ep_rew_mean      | -0.18    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.42     |
| time/               |          |
|    episodes         | 2400000  |
|    fps              | 613      |
|    time_elapsed     | 14972    |
|    total_timesteps  | 9183097  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.84     |
|    ep_rew_mean      | -0.74    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.17     |
| time/               |          |
|    episodes         | 700000   |
|    fps              | 633      |
|    time_elapsed     | 4220     |
|    total_timesteps  | 2674375  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0324   |
|    n_updates        | 656093   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.99     |
|    ep_rew_mean      | -0.48    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.3      |
| time/               |          |
|    episodes         | 800000   |
|    fps              | 591      |
|    time_elapsed     | 5304     |
|    total_timesteps  | 3138013  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.88     |
|    ep_rew_mean      | 0.22     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.68     |
| time/               |          |
|    episodes         | 2200000  |
|    fps              | 591      |
|    time_elapsed     | 14980    |
|    total_timesteps  | 8867871  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0873   |
|    n_updates        | 2204467  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.83     |
|    ep_rew_mean      | -0.46    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.36     |
| time/               |          |
|    episodes         | 2300000  |
|    fps              | 593      |
|    time_elapsed     | 15638    |
|    total_timesteps  | 9282276  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.47     |
|    ep_rew_mean      | -0.96    |
|    exploration_rate | 0.381    |
|    success_rate     | 0.06     |
| time/               |          |
|    episodes         | 300000   |
|    fps              | 1288     |
|    time_elapsed     | 757      |
|    total_timesteps  | 976882   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 117      |
|    n_updates        | 231720   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.33     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.151    |
|    success_rate     | 0.06     |
| time/               |          |
|    episodes         | 400000   |
|    fps              | 1249     |
|    time_elapsed     | 1072     |
|    total_timesteps  | 1340315  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.61     |
|    ep_rew_mean      | -0.06    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.47     |
| time/               |          |
|    episodes         | 1800000  |
|    fps              | 1076     |
|    time_elapsed     | 6992     |
|    total_timesteps  | 7524160  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0703   |
|    n_updates        | 1868539  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.37     |
|    ep_rew_mean      | -0.14    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.45     |
| time/               |          |
|    episodes         | 1900000  |
|    fps              | 1073     |
|    time_elapsed     | 7418     |
|    total_timesteps  | 7960480  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.39     |
|    ep_rew_mean      | 0.12     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.61     |
| time/               |          |
|    episodes         | 3300000  |
|    fps              | 1044     |
|    time_elapsed     | 13492    |
|    total_timesteps  | 14096502 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.161    |
|    n_updates        | 3511625  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.36     |
|    ep_rew_mean      | 0.1      |
|    exploration_rate | 0.05     |
|    success_rate     | 0.63     |
| time/               |          |
|    episodes         | 3400000  |
|    fps              | 1043     |
|    time_elapsed     | 13920    |
|    total_timesteps  | 14521667 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.08     |
|    ep_rew_mean      | -0.44    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.32     |
| time/               |          |
|    episodes         | 1300000  |
|    fps              | 1115     |
|    time_elapsed     | 4767     |
|    total_timesteps  | 5318680  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0341   |
|    n_updates        | 1317169  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.33     |
|    ep_rew_mean      | -0.44    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.32     |
| time/               |          |
|    episodes         | 1400000  |
|    fps              | 1107     |
|    time_elapsed     | 5217     |
|    total_timesteps  | 5777946  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.65     |
|    ep_rew_mean      | -0.04    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.52     |
| time/               |          |
|    episodes         | 2800000  |
|    fps              | 747      |
|    time_elapsed     | 16089    |
|    total_timesteps  | 12033225 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.188    |
|    n_updates        | 2995806  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.62     |
|    ep_rew_mean      | -0.06    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.51     |
| time/               |          |
|    episodes         | 2900000  |
|    fps              | 739      |
|    time_elapsed     | 16891    |
|    total_timesteps  | 12494440 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.42     |
|    ep_rew_mean      | 0.18     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.63     |
| time/               |          |
|    episodes         | 4300000  |
|    fps              | 671      |
|    time_elapsed     | 28192    |
|    total_timesteps  | 18945259 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0551   |
|    n_updates        | 4723814  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.97     |
|    ep_rew_mean      | 0.26     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.67     |
| time/               |          |
|    episodes         | 4400000  |
|    fps              | 668      |
|    time_elapsed     | 29043    |
|    total_timesteps  | 19413742 |
| train/              |          |
|    learning_rate  

In [37]:
# Evaluate a saved agent on the held-out split. Every cell below reads `test_df`,
# so this cell has to execute for the notebook to survive Restart & Run All.
# NOTE(review): the checkpoint name says "noisy_6", while this notebook loads and
# trains on the "noisy_4" dataset — confirm this is the model you mean to evaluate.
training_env = utils.create_env(X_train, y_train)
dqn_model = utils.load_dqn3('../../models/many_features/0.1/dqn3_by_type_new_labels_noisy_6_16000000', training_env)
test_df = utils.evaluate_dqn(dqn_model, X_test, y_test)
test_df.head()

Using stable baselines 3
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Count: 2800
Count: 5600
Count: 8400
Count: 11200
Count: 14000
Testing done.....


Unnamed: 0,episode_length,index,is_success,reward,terminated,trajectory,y_actual,y_pred
0,5.0,0.0,0.0,-1.0,1.0,"[hemoglobin, mcv, gender, tibc, tibc]",6.0,7.0
1,3.0,1.0,1.0,1.0,0.0,"[hemoglobin, mcv, Unspecified anemia]",2.0,2.0
2,4.0,2.0,1.0,1.0,0.0,"[hemoglobin, mcv, segmented_neutrophils, Unspe...",2.0,2.0
3,4.0,3.0,0.0,-1.0,0.0,"[hemoglobin, mcv, tibc, Anemia of chronic dise...",5.0,3.0
4,4.0,4.0,0.0,-1.0,0.0,"[hemoglobin, mcv, gender, Aplastic anemia]",5.0,6.0


In [38]:
#utils.diagnose_sample(dqn_model, X_test, y_test, 1)

In [39]:
# test_df[(test_df.y_pred==1) & (test_df.y_actual==1)]

In [40]:
# How often each class id is predicted; a class absent from this list is never
# chosen by the agent (the saved output lists only 7 distinct ids).
test_df['y_pred'].value_counts()

7.0    2659
3.0    2445
4.0    1956
0.0    1939
1.0    1806
2.0    1643
6.0    1552
Name: y_pred, dtype: int64

In [41]:
# Share of successful test episodes, plus a second frame returned alongside it.
# NOTE(review): the saved value reads as a percentage, and `success_df` is presumably
# the successful rows only — confirm both in utils.success_rate.
success_rate, success_df = utils.success_rate(test_df)
success_rate

67.36428571428571

In [42]:
# Mean episode length and mean episode return over the test episodes
# (as the variable names suggest — confirm in utils.get_avg_length_reward).
avg_length, avg_return = utils.get_avg_length_reward(test_df)
avg_length, avg_return

(3.994285714285714, 0.25442857142857145)

In [43]:
# Classification metrics of the agent's final predictions against the true classes.
# NOTE(review): the averaging used for F1 / ROC-AUC is decided inside utils.test — confirm
# before comparing these numbers with the decision-tree baseline above.
acc, f1, roc_auc = utils.test(test_df['y_actual'], test_df['y_pred'])
acc, f1, roc_auc

(0.6736428571428571, 0.62354652826838, 0.8100958870169724)

In [44]:
# Distinct class ids the agent ever predicts, in order of first appearance.
test_df['y_pred'].unique()

array([7., 2., 3., 6., 0., 4., 1.])

In [26]:
# test_df.to_csv(f'../../test_dfs/many_features/0.1/test_df3_noisy_1_11000000.csv', index=False)
# success_df.to_csv(f'../../test_dfs/many_features/0.1/success_df3_noisy_1_11000000.csv', index=False)