In [1]:
import pandas as pd
import numpy as np
import random
import os
import torch
import sys
# Make the parent directory importable so the project-local `modules` package
# (utils, constants) resolves; must run before the import below.
# Assumes the notebook is executed from a direct subdirectory of the repo root — TODO confirm.
sys.path.append('../')
from modules import utils, constants
import warnings
# NOTE(review): this silences *every* warning for the whole kernel (pandas, torch,
# stable-baselines3 deprecations included); consider filtering specific categories instead.
warnings.filterwarnings('ignore')

In [2]:
# Global seed shared by every RNG used in this notebook.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so setting it here does
# not change str/bytes hash randomization of the running kernel; it only affects
# Python subprocesses launched afterwards. Export it before starting Jupyter if
# in-process hash determinism is needed.
os.environ['PYTHONHASHSEED'] = str(SEED)
torch.manual_seed(SEED)
# With deterministic algorithms enabled, cuBLAS-backed CUDA ops raise a RuntimeError
# unless CUBLAS_WORKSPACE_CONFIG is set (PyTorch reproducibility notes). It must be set
# before cuBLAS is first used; setdefault keeps any value already exported.
# No effect on CPU-only runs.
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
torch.use_deterministic_algorithms(True)

#### The data

In [3]:
# Load the balanced training set that every DQN run below is trained on.
train_df = pd.read_csv(filepath_or_buffer='../data/train_set_basic_balanced.csv')
train_df.head(n=5)

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,proteinuria,renal_biopsy_class,anti_cardioliphin_antibodies,anti_b2gp1_antibodies,lupus_anti_coagulant,c3,c4,anti_dsdna_antibody,anti_smith_antibody,label
0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1
1,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,2.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0
3,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
4,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,2.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1


In [4]:
# Count missing values per column (every count is expected to be zero).
train_df.isnull().sum(axis=0)

ana                             0
fever                           0
leukopenia                      0
thrombocytopenia                0
auto_immune_hemolysis           0
delirium                        0
psychosis                       0
seizure                         0
non_scarring_alopecia           0
oral_ulcers                     0
subacute_cutaneous              0
discoid_lupus                   0
acute_cutaneous_lupus           0
pleural_effusion                0
pericardial_effusion            0
acute_pericarditis              0
joint_involvement               0
proteinuria                     0
renal_biopsy_class              0
anti_cardioliphin_antibodies    0
anti_b2gp1_antibodies           0
lupus_anti_coagulant            0
c3                              0
c4                              0
anti_dsdna_antibody             0
anti_smith_antibody             0
label                           0
dtype: int64

In [5]:
# Number of training rows.
train_df.shape[0]

56000

In [6]:
# Split features from the target by column *name* rather than by position. The
# action list is built from `col != 'label'`, so this keeps the feature order
# identical to it even if 'label' is not the last CSV column. It also fails loudly
# (KeyError) if 'label' is missing, instead of silently using a feature as the target.
X_train = train_df.drop(columns='label').to_numpy()
y_train = train_df['label'].to_numpy()
assert X_train.shape[0] == y_train.shape[0], "features and labels are misaligned"
X_train.shape, y_train.shape

((56000, 26), (56000,))

In [7]:
# Actions as listed here: the class names from constants.CLASS_DICT first, then
# every feature column name (all columns except 'label'), in CSV order.
action_list = [*constants.CLASS_DICT.keys()]
action_list.extend(column for column in train_df.columns if column != 'label')
action_list

['No lupus',
 'Lupus',
 'Inconclusive diagnosis',
 'ana',
 'fever',
 'leukopenia',
 'thrombocytopenia',
 'auto_immune_hemolysis',
 'delirium',
 'psychosis',
 'seizure',
 'non_scarring_alopecia',
 'oral_ulcers',
 'subacute_cutaneous',
 'discoid_lupus',
 'acute_cutaneous_lupus',
 'pleural_effusion',
 'pericardial_effusion',
 'acute_pericarditis',
 'joint_involvement',
 'proteinuria',
 'renal_biopsy_class',
 'anti_cardioliphin_antibodies',
 'anti_b2gp1_antibodies',
 'lupus_anti_coagulant',
 'c3',
 'c4',
 'anti_dsdna_antibody',
 'anti_smith_antibody']

In [12]:
# Total environment steps for each DQN training run.
TRAINING_STEPS = [int(25e6), int(30e6)]

# Keep every trained model (keyed by step budget) instead of only the last one.
# `dqn_model` still ends up holding the final model for any later cells.
dqn_models = {}
for steps in TRAINING_STEPS:
    # Re-seed before each run so its randomness does not depend on how many random
    # numbers earlier runs consumed. A single budget trained on its own then starts
    # from the same global RNG state as it does inside this loop.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # NOTE(review): the positional `True` and the path argument are interpreted by
    # utils.stable_dqn3 (not visible here). The path is presumably where the model is
    # saved — confirm, and pass the flag by keyword once its parameter name is known.
    # NOTE(review): if stable_dqn3 builds the SB3 model/env without a `seed`, their
    # internal RNGs (e.g. action-space sampling) are not covered by the seeds above — confirm.
    dqn_model = utils.stable_dqn3(X_train, y_train, steps, True, f'../models/dqn_balanced_{steps}')
    dqn_models[steps] = dqn_model

using stable baselines 3
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.59     |
|    ep_rew_mean      | -0.72    |
|    exploration_rate | 0.804    |
|    success_rate     | 0.14     |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 1073     |
|    time_elapsed     | 480      |
|    total_timesteps  | 515778   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0707   |
|    n_updates        | 116444   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.19     |
|    ep_rew_mean      | -0.78    |
|    exploration_rate | 0.58     |
|    success_rate     | 0.11     |
| time/               |          |
|    episodes         | 200000   |
|    fps              | 967      |
|    t

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.4      |
|    ep_rew_mean      | 0.32     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.66     |
| time/               |          |
|    episodes         | 1600000  |
|    fps              | 957      |
|    time_elapsed     | 10909    |
|    total_timesteps  | 10448316 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0564   |
|    n_updates        | 2599578  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.05     |
|    ep_rew_mean      | 0.46     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.73     |
| time/               |          |
|    episodes         | 1700000  |
|    fps              | 958      |
|    time_elapsed     | 11605    |
|    total_timesteps  | 11123430 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.6      |
|    ep_rew_mean      | 0.42     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.71     |
| time/               |          |
|    episodes         | 3100000  |
|    fps              | 985      |
|    time_elapsed     | 21534    |
|    total_timesteps  | 21214553 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0588   |
|    n_updates        | 5291138  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.5      |
|    ep_rew_mean      | 0.52     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.76     |
| time/               |          |
|    episodes         | 3200000  |
|    fps              | 992      |
|    time_elapsed     | 22138    |
|    total_timesteps  | 21972104 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.9      |
|    ep_rew_mean      | 0.26     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.63     |
| time/               |          |
|    episodes         | 1100000  |
|    fps              | 1294     |
|    time_elapsed     | 5455     |
|    total_timesteps  | 7063629  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.11     |
|    n_updates        | 1753407  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.84     |
|    ep_rew_mean      | 0.32     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.66     |
| time/               |          |
|    episodes         | 1200000  |
|    fps              | 1288     |
|    time_elapsed     | 6002     |
|    total_timesteps  | 7734797  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.37     |
|    ep_rew_mean      | 0.4      |
|    exploration_rate | 0.05     |
|    success_rate     | 0.7      |
| time/               |          |
|    episodes         | 2600000  |
|    fps              | 1257     |
|    time_elapsed     | 13945    |
|    total_timesteps  | 17540602 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0876   |
|    n_updates        | 4372650  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.17     |
|    ep_rew_mean      | 0.34     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.67     |
| time/               |          |
|    episodes         | 2700000  |
|    fps              | 1250     |
|    time_elapsed     | 14608    |
|    total_timesteps  | 18261931 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.61     |
|    ep_rew_mean      | 0.44     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.72     |
| time/               |          |
|    episodes         | 4100000  |
|    fps              | 1215     |
|    time_elapsed     | 23515    |
|    total_timesteps  | 28579529 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00569  |
|    n_updates        | 7132382  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.1      |
|    ep_rew_mean      | 0.38     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.69     |
| time/               |          |
|    episodes         | 4200000  |
|    fps              | 1209     |
|    time_elapsed     | 24255    |
|    total_timesteps  | 29333231 |
| train/              |          |
|    learning_rate  