In [12]:
import pandas as pd
import numpy as np
import random
import os
import torch
import sys
sys.path.append('../..')
from modules.many_features import utils, constants
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

In [13]:
# One seed shared by every random number generator this notebook touches,
# so that repeated runs are comparable.
SEED = 42

# NOTE: PYTHONHASHSEED is read when an interpreter starts, so assigning it
# here only reaches child processes, not this already-running kernel.
os.environ['PYTHONHASHSEED'] = f'{SEED}'

random.seed(SEED)        # Python standard-library RNG
np.random.seed(SEED)     # NumPy legacy global RNG
torch.manual_seed(SEED)  # PyTorch RNG

# Ask PyTorch to raise rather than silently fall back to a nondeterministic op.
torch.use_deterministic_algorithms(True)

#### The training data

In [14]:
# Load the basic training set and encode every missing value as -1.
# Other dataset paths kept for reference (not loaded here):
#   '../../final/data/new_experiments/train_set_basic.csv'
#   '../../final/data/train_set_noisy_6_missing_3.csv'
train_df = pd.read_csv('../../final/data/train_set_basic.csv').fillna(-1)
train_df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,9.007012,-1.0,-1.0,3.519565,440.499323,103.442762,59.017997,2.612173,1,0.650757,114.794964,112.308159,25.612786,5.96971,116.026042,27.021037,13.397977,1
1,8.760976,-1.0,0.491469,-1.0,259.895852,103.885481,-1.0,2.529991,0,0.728641,74.824352,-1.0,-1.0,-1.0,-1.0,26.282929,-1.0,7
2,7.490324,70.812609,-1.0,1.495604,482.109919,79.543391,-1.0,2.824995,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,22.470972,-1.0,4
3,8.11337,499.313483,3.507612,0.0,195.351883,100.545858,203.895756,2.420797,1,0.860863,9.120947,41.131511,1.84389,18.845473,106.199806,24.340111,104.373581,2
4,13.935301,349.569415,5.190725,6.894195,489.595939,102.234294,150.085853,4.089225,1,0.216907,20.344863,92.547095,19.815123,29.543875,98.38871,41.805903,30.655044,0


In [15]:
# Number of rows in the training frame.
train_df.shape[0]

56000

In [16]:
# Per-column count of NaN values still present after the -1 fill.
train_df.isnull().sum()

hemoglobin               0
ferritin                 0
ret_count                0
segmented_neutrophils    0
tibc                     0
mcv                      0
serum_iron               0
rbc                      0
gender                   0
creatinine               0
cholestrol               0
copper                   0
ethanol                  0
folate                   0
glucose                  0
hematocrit               0
tsat                     0
label                    0
dtype: int64

In [17]:
# Features are every column except the last one; the label is the last column.
# np.array(...) is used so both results are independent NumPy copies.
X_train = np.array(train_df.iloc[:, :-1])
y_train = np.array(train_df.iloc[:, -1])
(X_train.shape, y_train.shape)

((56000, 17), (56000,))

In [18]:
# Action names in order: the keys of constants.CLASS_DICT first, followed by
# one entry per column of train_df other than 'label'.
action_list = [
    *constants.CLASS_DICT.keys(),
    *(name for name in train_df.columns if name != 'label'),
]
action_list

['No anemia',
 'Vitamin B12/Folate deficiency anemia',
 'Unspecified anemia',
 'Anemia of chronic disease',
 'Iron deficiency anemia',
 'Hemolytic anemia',
 'Aplastic anemia',
 'Inconclusive diagnosis',
 'hemoglobin',
 'ferritin',
 'ret_count',
 'segmented_neutrophils',
 'tibc',
 'mcv',
 'serum_iron',
 'rbc',
 'gender',
 'creatinine',
 'cholestrol',
 'copper',
 'ethanol',
 'folate',
 'glucose',
 'hematocrit',
 'tsat']

#### The Model

In [19]:
from stable_baselines3 import DQN

In [20]:
# Imported in this cell because it is the first place the name is used; the
# only other import of CheckpointCallback in the notebook appears in a later
# cell and comes from the legacy `stable_baselines` package, so on a fresh
# kernel this cell would otherwise fail with a NameError.
from stable_baselines3.common.callbacks import CheckpointCallback

# Train a DQN agent on the environment built from the training arrays,
# writing periodic checkpoints, then save the final model.
training_env = utils.create_env(X_train, y_train)
try:
    model = DQN('MlpPolicy', training_env, verbose=1, seed=constants.SEED)

    # Periodic snapshots so a long run can be resumed or inspected part-way.
    checkpoint_callback = CheckpointCallback(
        save_freq=500000,
        save_path='../../final/models/new_experiments/logs/trial_3',
        name_prefix='dqn_basic',
    )

    # The captured log shows one entry per 500000 episodes, matching
    # log_interval.
    model.learn(total_timesteps=50000000, log_interval=500000, callback=checkpoint_callback)
    model.save('../../final/models/new_experiments/dqn_basic_anemia_diagnosis_trial_3')
finally:
    # Always release the environment, even if training fails or is interrupted.
    training_env.close()

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.01     |
|    ep_rew_mean      | -0.84    |
|    exploration_rate | 0.729    |
|    success_rate     | 0.09     |
| time/               |          |
|    episodes         | 500000   |
|    fps              | 2996     |
|    time_elapsed     | 476      |
|    total_timesteps  | 1426906  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 5.54e+06 |
|    n_updates        | 344226   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.88     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.457    |
|    success_rate     | 0.08     |
| time/               |          |
|    episodes         | 1000000  |
|    fps              | 2826     |
|    time_elapsed     | 1011   

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.01     |
|    ep_rew_mean      | -0.56    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 8000000  |
|    fps              | 1804     |
|    time_elapsed     | 7867     |
|    total_timesteps  | 14192573 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.178    |
|    n_updates        | 3535643  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.04     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.05     |
|    success_rate     | 0.1      |
| time/               |          |
|    episodes         | 8500000  |
|    fps              | 1775     |
|    time_elapsed     | 8289     |
|    total_timesteps  | 14716302 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.05     |
|    ep_rew_mean      | -0.7     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.15     |
| time/               |          |
|    episodes         | 15500000 |
|    fps              | 1551     |
|    time_elapsed     | 14205    |
|    total_timesteps  | 22034199 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.179    |
|    n_updates        | 5496049  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.02     |
|    ep_rew_mean      | -0.72    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.14     |
| time/               |          |
|    episodes         | 16000000 |
|    fps              | 1543     |
|    time_elapsed     | 14613    |
|    total_timesteps  | 22555290 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.72    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.14     |
| time/               |          |
|    episodes         | 23000000 |
|    fps              | 1479     |
|    time_elapsed     | 20194    |
|    total_timesteps  | 29876156 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0532   |
|    n_updates        | 7456538  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.74    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.13     |
| time/               |          |
|    episodes         | 23500000 |
|    fps              | 1473     |
|    time_elapsed     | 20630    |
|    total_timesteps  | 30400547 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.82    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.09     |
| time/               |          |
|    episodes         | 30500000 |
|    fps              | 1410     |
|    time_elapsed     | 26743    |
|    total_timesteps  | 37716459 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.22     |
|    n_updates        | 9416614  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.02     |
|    ep_rew_mean      | -0.68    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.16     |
| time/               |          |
|    episodes         | 31000000 |
|    fps              | 1414     |
|    time_elapsed     | 27025    |
|    total_timesteps  | 38239698 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.05     |
|    ep_rew_mean      | -0.62    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.19     |
| time/               |          |
|    episodes         | 38000000 |
|    fps              | 1453     |
|    time_elapsed     | 31352    |
|    total_timesteps  | 45555924 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.136    |
|    n_updates        | 11376480 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.02     |
|    ep_rew_mean      | -0.7     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.15     |
| time/               |          |
|    episodes         | 38500000 |
|    fps              | 1454     |
|    time_elapsed     | 31674    |
|    total_timesteps  | 46078096 |
| train/              |          |
|    learning_rate  

#### Original trial (legacy `stable_baselines`; the cells below were not executed)

In [None]:
# from stable_baselines3 import DQN
from stable_baselines import DQN
# from stable_baselines3 import bench, logger
from modules.many_features.env import SyntheticEnv
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines.common.callbacks import CheckpointCallback

In [None]:
# Keep the raw environment under its own name. The factory lambda looks its
# variable up only when it is called, so closing over `training_env` (which is
# rebound to the vectorised wrapper on the next line) would hand back the
# wrapper itself if the factory were ever invoked after that rebinding.
raw_env = SyntheticEnv(X_train, y_train)
training_env = DummyVecEnv([lambda: raw_env])
# Optional observation normalisation (currently disabled):
# training_env = VecNormalize(training_env, norm_obs=True, norm_reward=False, clip_obs=10.)

try:
    # Define and train the DQN agent
    model = DQN('MlpPolicy', training_env, verbose=1, seed=constants.SEED)

    # Periodic snapshots of the agent during training.
    checkpoint_callback = CheckpointCallback(
        save_freq=100000,
        save_path='../../final/models/new_experiments/logs/basic_sb',
        name_prefix='dqn_sb_basic',
    )

    model.learn(total_timesteps=20000000, log_interval=100000, callback=checkpoint_callback)

    # Save the trained DQN agent
    model.save('../../final/models/new_experiments/dqn_basic_anemia_diagnosis_sb')
finally:
    # Always release the environment, even if training fails or is interrupted.
    training_env.close()