In [1]:
import pandas as pd
import numpy as np
from datetime import datetime
import random
import os
import torch
import sys
sys.path.append('../')
from lillian_modules import utils, constants
from lillian_modules.env import AnemiaEnv
from stable_baselines3.dqn import DQN
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

In [2]:
SEED = constants.SEED

# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so assigning it here
# affects child processes only, not this kernel's str/bytes hashing. Set it in the
# shell / kernel spec if hash-order determinism in this process matters.
os.environ['PYTHONHASHSEED'] = str(SEED)
# torch.use_deterministic_algorithms(True) raises a RuntimeError for cuBLAS ops on
# CUDA >= 10.2 unless this variable is set. It is a no-op when running on CPU.
# setdefault keeps any value the user already exported.
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

# Seed every RNG source used by the notebook / training stack.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(True)

#### The data

In [3]:
def split_features_label(frame):
    """Split a frame into (features, label) NumPy arrays; the LAST column is the label."""
    features = np.array(frame.iloc[:, :-1])
    label = np.array(frame.iloc[:, -1])
    return features, label


# Missing values in the (noisy) training set are encoded as -1.
# NOTE(review): validation/testing frames are NOT filled -- confirm they contain no
# NaNs, or that the environment / validation code handles NaN the same way as -1.
training_df = pd.read_csv('../data/training_noisiness_0.1.csv').fillna(-1)
validation_df = pd.read_csv('../data/validation_set.csv')
testing_df = pd.read_csv('../data/testing_set.csv')

X_train, y_train = split_features_label(training_df)
X_val, y_val = split_features_label(validation_df)
X_test, y_test = split_features_label(testing_df)
X_train.shape, X_val.shape, X_test.shape, y_train.shape, y_val.shape, y_test.shape

((50400, 17), (5600, 17), (14000, 17), (50400,), (5600,), (14000,))

In [4]:
# Action space: the diagnosis classes first, followed by one action per feature column.
# NOTE(review): the label is excluded here by NAME ('label'), whereas the data cell
# takes the label by POSITION (last column) -- these agree only while the last
# column is called 'label'.
diagnosis_actions = list(constants.CLASS_DICT.keys())
feature_actions = [name for name in training_df.columns if name != 'label']
action_list = diagnosis_actions + feature_actions
action_list

['No anemia',
 'Vitamin B12/Folate deficiency anemia',
 'Unspecified anemia',
 'Anemia of chronic disease',
 'Iron deficiency anemia',
 'Hemolytic anemia',
 'Aplastic anemia',
 'Inconclusive diagnosis',
 'hemoglobin',
 'ferritin',
 'ret_count',
 'segmented_neutrophils',
 'tibc',
 'mcv',
 'serum_iron',
 'rbc',
 'gender',
 'creatinine',
 'cholestrol',
 'copper',
 'ethanol',
 'folate',
 'glucose',
 'hematocrit',
 'tsat']

#### Training

In [5]:
def stable_dqn3(X_train, y_train, timesteps, validation_times=10, folder_name=None,
                X_val=None, y_val=None, log_interval=100000):
    """Train a DQN agent on an ``AnemiaEnv`` built from the training data.

    Parameters
    ----------
    X_train, y_train : array-like
        Training features and labels used to construct the environment.
    timesteps : int
        Passed to ``model.learn`` as ``total_timesteps``.
    validation_times : int, default 10
        Forwarded to ``model.learn``.
        NOTE(review): this keyword and the ``X_val`` / ``y_val`` / ``to_save_folder``
        constructor arguments are not part of upstream stable-baselines3. The
        imported ``DQN`` is presumably a project-local fork -- confirm.
    folder_name : str or None, default None
        Forwarded to the DQN constructor as ``to_save_folder``.
    X_val, y_val : array-like or None, default None
        Validation data handed to the DQN constructor. When ``None``, the
        notebook-level ``X_val`` / ``y_val`` globals are used, which preserves
        the behaviour of earlier calls.
    log_interval : int, default 100000
        Forwarded to ``model.learn``.

    Returns
    -------
    The trained ``DQN`` model.
    """
    # Fall back to the notebook globals the function used to read implicitly.
    namespace = globals()
    if X_val is None:
        X_val = namespace['X_val']
    if y_val is None:
        y_val = namespace['y_val']

    training_env = AnemiaEnv(X_train, y_train)
    try:
        print(f'ACTION SPACE: {training_env.actions}')
        model = DQN(policy='MlpPolicy', env=training_env, verbose=1, seed=SEED,
                    X_val=X_val, y_val=y_val, to_save_folder=folder_name)
        model.learn(total_timesteps=timesteps, validation_times=validation_times,
                    log_interval=log_interval)
    finally:
        # Release the environment even if model construction or training fails.
        training_env.close()
    return model

In [6]:
ft_num = 17  # number of feature columns (X_train.shape[1] in the data cell)
steps = 30_000_000  # total DQN environment steps
NOISINESS = 0.1  # must match the noisiness level of the training CSV loaded above

# validation_times=300 over 30M steps -> one validation every 100,000 steps.
model = stable_dqn3(X_train, y_train, timesteps=steps, validation_times=300,
                    folder_name=f'../models/noisiness_{NOISINESS}')

ACTION SPACE: ['No anemia', 'Vitamin B12/Folate deficiency anemia', 'Unspecified anemia', 'Anemia of chronic disease', 'Iron deficiency anemia', 'Hemolytic anemia', 'Aplastic anemia', 'Inconclusive diagnosis', 'hemoglobin', 'ferritin', 'ret_count', 'segmented_neutrophils', 'tibc', 'mcv', 'serum_iron', 'rbc', 'gender', 'creatinine', 'cholestrol', 'copper', 'ethanol', 'folate', 'glucose', 'hematocrit', 'tsat']
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
setting up model
Validating at 100000/30000000
Validating at 200000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.1      |
|    ep_rew_mean      | -0.82    |
|    exploration_rate | 0.913    |
|    success_rate     | 0.12     |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 2721     |
|    time_elapsed     | 100      |
|    total_timesteps  | 274696   |
| train/              |         

Validating at 3300000/30000000
Validating at 3400000/30000000
Validating at 3500000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.5      |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.05     |
|    success_rate     | 0.1      |
| time/               |          |
|    episodes         | 1300000  |
|    fps              | 2042     |
|    time_elapsed     | 1745     |
|    total_timesteps  | 3564210  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 4.38e+06 |
|    n_updates        | 878552   |
----------------------------------
Validating at 3600000/30000000
Validating at 3700000/30000000
Validating at 3800000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.23     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.05     |
|    success_rate     | 0.08     |
| time/               |          |
|    epis

Validating at 6500000/30000000
Validating at 6600000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.07     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.05     |
|    success_rate     | 0.09     |
| time/               |          |
|    episodes         | 2600000  |
|    fps              | 1918     |
|    time_elapsed     | 3446     |
|    total_timesteps  | 6611889  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.168    |
|    n_updates        | 1640472  |
----------------------------------
Validating at 6700000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1        |
|    ep_rew_mean      | -0.44    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.28     |
| time/               |          |
|    episodes         | 2700000  |
|    fps              | 1913     |
|    time_elapsed     | 3536    

Validating at 8100000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.05     |
|    ep_rew_mean      | -0.46    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.27     |
| time/               |          |
|    episodes         | 4000000  |
|    fps              | 1879     |
|    time_elapsed     | 4324     |
|    total_timesteps  | 8128422  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.405    |
|    n_updates        | 2019605  |
----------------------------------
Validating at 8200000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.56    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 4100000  |
|    fps              | 1878     |
|    time_elapsed     | 4383     |
|    total_timesteps  | 8233

Validating at 9500000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.62    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.19     |
| time/               |          |
|    episodes         | 5400000  |
|    fps              | 1790     |
|    time_elapsed     | 5355     |
|    total_timesteps  | 9587561  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.15     |
|    n_updates        | 2384390  |
----------------------------------
Validating at 9600000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.56    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 5500000  |
|    fps              | 1790     |
|    time_elapsed     | 5413     |
|    total_timesteps  | 9692

Validating at 11000000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.62    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.19     |
| time/               |          |
|    episodes         | 6800000  |
|    fps              | 1787     |
|    time_elapsed     | 6179     |
|    total_timesteps  | 11045864 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.185    |
|    n_updates        | 2748965  |
----------------------------------
Validating at 11100000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.02     |
|    ep_rew_mean      | -0.68    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.16     |
| time/               |          |
|    episodes         | 6900000  |
|    fps              | 1787     |
|    time_elapsed     | 6237     |
|    total_timesteps  | 11

Validating at 12500000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.52    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.24     |
| time/               |          |
|    episodes         | 8200000  |
|    fps              | 1785     |
|    time_elapsed     | 7005     |
|    total_timesteps  | 12506791 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.259    |
|    n_updates        | 3114197  |
----------------------------------
Validating at 12600000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.09     |
|    ep_rew_mean      | -0.66    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.17     |
| time/               |          |
|    episodes         | 8300000  |
|    fps              | 1785     |
|    time_elapsed     | 7063     |
|    total_timesteps  | 12

Validating at 13900000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.07     |
|    ep_rew_mean      | -0.54    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.23     |
| time/               |          |
|    episodes         | 9600000  |
|    fps              | 1785     |
|    time_elapsed     | 7822     |
|    total_timesteps  | 13966554 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.333    |
|    n_updates        | 3479138  |
----------------------------------
Validating at 14000000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.02     |
|    ep_rew_mean      | -0.64    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.18     |
| time/               |          |
|    episodes         | 9700000  |
|    fps              | 1785     |
|    time_elapsed     | 7881     |
|    total_timesteps  | 14

Validating at 15400000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.5     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.25     |
| time/               |          |
|    episodes         | 11000000 |
|    fps              | 1784     |
|    time_elapsed     | 8644     |
|    total_timesteps  | 15425629 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.366    |
|    n_updates        | 3843907  |
----------------------------------
Validating at 15500000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.06     |
|    ep_rew_mean      | -0.48    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.26     |
| time/               |          |
|    episodes         | 11100000 |
|    fps              | 1784     |
|    time_elapsed     | 8702     |
|    total_timesteps  | 15

Validating at 16800000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.64    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.18     |
| time/               |          |
|    episodes         | 12400000 |
|    fps              | 1412     |
|    time_elapsed     | 11955    |
|    total_timesteps  | 16886539 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.222    |
|    n_updates        | 4209134  |
----------------------------------
Validating at 16900000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.54    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.23     |
| time/               |          |
|    episodes         | 12500000 |
|    fps              | 1412     |
|    time_elapsed     | 12026    |
|    total_timesteps  | 16

Validating at 18300000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.58    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.21     |
| time/               |          |
|    episodes         | 13800000 |
|    fps              | 1420     |
|    time_elapsed     | 12915    |
|    total_timesteps  | 18345307 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.22     |
|    n_updates        | 4573826  |
----------------------------------
Validating at 18400000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.6     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.2      |
| time/               |          |
|    episodes         | 13900000 |
|    fps              | 1420     |
|    time_elapsed     | 12983    |
|    total_timesteps  | 18

Validating at 19800000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.48    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.26     |
| time/               |          |
|    episodes         | 15200000 |
|    fps              | 1422     |
|    time_elapsed     | 13923    |
|    total_timesteps  | 19805888 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.152    |
|    n_updates        | 4938971  |
----------------------------------
Validating at 19900000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1        |
|    ep_rew_mean      | -0.48    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.26     |
| time/               |          |
|    episodes         | 15300000 |
|    fps              | 1422     |
|    time_elapsed     | 13993    |
|    total_timesteps  | 19

Validating at 21200000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.02     |
|    ep_rew_mean      | -0.64    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.18     |
| time/               |          |
|    episodes         | 16600000 |
|    fps              | 1427     |
|    time_elapsed     | 14899    |
|    total_timesteps  | 21265237 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.224    |
|    n_updates        | 5303809  |
----------------------------------
Validating at 21300000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.48    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.26     |
| time/               |          |
|    episodes         | 16700000 |
|    fps              | 1427     |
|    time_elapsed     | 14966    |
|    total_timesteps  | 21

Validating at 22700000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.56    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 18000000 |
|    fps              | 1431     |
|    time_elapsed     | 15875    |
|    total_timesteps  | 22726316 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.37     |
|    n_updates        | 5669078  |
----------------------------------
Validating at 22800000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.02     |
|    ep_rew_mean      | -0.56    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 18100000 |
|    fps              | 1432     |
|    time_elapsed     | 15941    |
|    total_timesteps  | 22

Validating at 24100000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.58    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.21     |
| time/               |          |
|    episodes         | 19400000 |
|    fps              | 1435     |
|    time_elapsed     | 16842    |
|    total_timesteps  | 24185833 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.292    |
|    n_updates        | 6033958  |
----------------------------------
Validating at 24200000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.42    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.29     |
| time/               |          |
|    episodes         | 19500000 |
|    fps              | 1436     |
|    time_elapsed     | 16914    |
|    total_timesteps  | 24

Validating at 25600000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.56    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 20800000 |
|    fps              | 1433     |
|    time_elapsed     | 17890    |
|    total_timesteps  | 25644859 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.223    |
|    n_updates        | 6398714  |
----------------------------------
Validating at 25700000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.54    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.23     |
| time/               |          |
|    episodes         | 20900000 |
|    fps              | 1433     |
|    time_elapsed     | 17960    |
|    total_timesteps  | 25

Validating at 27100000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.62    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.19     |
| time/               |          |
|    episodes         | 22200000 |
|    fps              | 1435     |
|    time_elapsed     | 18880    |
|    total_timesteps  | 27105321 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.372    |
|    n_updates        | 6763830  |
----------------------------------
Validating at 27200000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.06     |
|    ep_rew_mean      | -0.64    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.18     |
| time/               |          |
|    episodes         | 22300000 |
|    fps              | 1435     |
|    time_elapsed     | 18951    |
|    total_timesteps  | 27

Validating at 28500000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.66    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.17     |
| time/               |          |
|    episodes         | 23600000 |
|    fps              | 1434     |
|    time_elapsed     | 19919    |
|    total_timesteps  | 28564995 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.259    |
|    n_updates        | 7128748  |
----------------------------------
Validating at 28600000/30000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.38    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.31     |
| time/               |          |
|    episodes         | 23700000 |
|    fps              | 1433     |
|    time_elapsed     | 19994    |
|    total_timesteps  | 28

Validating at 30000000/30000000
