In [1]:
import pandas as pd
import numpy as np
# import tensorflow as tf
import random
import glob
import os
from multiprocessing import Process
import sys
sys.path.append('..')
from modules import constants
from modules.env import LupusEnv
import torch
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback

In [2]:
# Seed every random number generator used by this notebook from the shared
# project seed so that repeated runs start from the same state.
SEED = constants.SEED
# NOTE: Python reads PYTHONHASHSEED when an interpreter starts, so this
# assignment cannot change string hashing in the already-running kernel; it is
# only visible to interpreters started afterwards that inherit this environment.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Require PyTorch to use deterministic implementations of its operations.
torch.use_deterministic_algorithms(True)
SEED

42

In [3]:
# Display the BETA constant from modules.constants so the value in effect for
# this run is recorded alongside the results.
constants.BETA

9

In [4]:
def create_env(X, y, random=True):
    '''
    Creates an environment using the given data.

    X, y and random are forwarded positionally to LupusEnv. The seed reported
    by the new environment is printed before the environment is returned.

    Note: inside this function the `random` parameter shadows the imported
    `random` module.
    '''
    lupus_env = LupusEnv(X, y, random)
    reported_seed = lupus_env.seed()
    print(f'The environment seed is {reported_seed}') #to delete
    return lupus_env

In [5]:
def stable_baselines3_dqn(X_train, y_train, steps, save, log_path, log_prefix, filename):
    '''
    Trains a Stable-Baselines3 DQN agent ('MlpPolicy') on an environment built
    from the given training data.

    X_train, y_train: data passed to create_env to build the training environment.
    steps: total number of timesteps passed to model.learn.
    save: when truthy, the trained model is saved to '<log_path>/<filename>_full_model'.
    log_path: directory that receives the periodic checkpoints (and the final model).
    log_prefix: file name prefix used for the periodic checkpoints.
    filename: base name of the final saved model.

    Returns the trained DQN model. The training environment is always closed
    before returning, including when training or saving raises.
    '''
    training_env = create_env(X_train, y_train)
    try:
        model = DQN('MlpPolicy', training_env, verbose=1, seed=constants.SEED)
        # Checkpoint every constants.CHECKPOINT_FREQ calls so a long run can be
        # inspected or resumed from an intermediate state.
        checkpoint_callback = CheckpointCallback(save_freq=constants.CHECKPOINT_FREQ, save_path=log_path,
                                                 name_prefix=log_prefix)
        model.learn(total_timesteps=steps, log_interval=100000, callback=checkpoint_callback)
        if save:
            model.save(f'{log_path}/{filename}_full_model')
    finally:
        # Release the environment even if learn() or save() fails part-way.
        training_env.close()
    return model

In [6]:
def run_dqn_model(model_type, steps):
    '''
    Trains and saves one model, writing checkpoints and the final model under
    '../models/logs/<model_type>/missingness/0.1/biopsy_9/sb3/seed_<SEED>_<steps>'.

    Reads the notebook-level globals X_train, y_train and SEED.

    model_type: name used for the log directory, the checkpoint prefix and the
        saved file name.
    steps: total number of training timesteps.

    Returns the trained model.
    '''
    dir_name = f'seed_{SEED}_{steps}'
    parent_dir = f'../models/logs/{model_type}/missingness/0.1/biopsy_9/sb3'
    path = os.path.join(parent_dir, dir_name)
    # makedirs creates any missing parent directories and tolerates an existing
    # run directory, so the notebook can be re-run from a fresh kernel.
    # Caution: re-running with the same seed and step count writes into the
    # same directory and can overwrite earlier checkpoints.
    os.makedirs(path, exist_ok=True)
    model = stable_baselines3_dqn(X_train, y_train, steps, save=True, log_path=path, log_prefix=model_type,
                                  filename=f'{model_type}_{steps}')
    return model

In [7]:
# Read the training split and replace every missing value with the sentinel -1.
train_df = pd.read_csv('../new_data/train_set_missingness_0.1.csv').fillna(-1)
train_df.head()

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,proteinuria,biopsy_proven_lupus_nephritis,anti_cardioliphin_antibodies,anti_β2gp1_antibodies,lupus_anti_coagulant,low_c3,low_c4,anti_dsdna_antibody,anti_smith_antibody,label
0,1,1.0,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1
1,1,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1
2,0,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0
3,0,0.0,0.0,0.0,0.0,-1.0,0.0,0.0,-1.0,0.0,...,1.0,0.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0
4,1,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,...,1.0,3.0,0.0,-1.0,0.0,0.0,-1.0,0.0,0.0,1


In [8]:
# Class balance of the training labels.
train_df['label'].value_counts()

0    25240
1    25160
Name: label, dtype: int64

In [9]:
# Every column except the last is a feature; the last column is the label.
X_train = np.array(train_df.iloc[:, :-1])
y_train = np.array(train_df.iloc[:, -1])
X_train.shape, y_train.shape

((50400, 24), (50400,))

In [10]:
# Variant names previously listed here: dqn, ddqn, dueling_dqn, dueling_ddqn,
# dqn_per, ddqn_per, dueling_dqn_per, dueling_ddqn_per.
# Only the last two are launched in this run.
model_names = ['dueling_dqn_per', 'dueling_ddqn_per']
steps = 100_000_000  # total training timesteps per job (equal to int(10e7))
procs = []  # handles of the worker processes started by the launch cell

In [11]:
# Launch one training job per model name, each in its own OS process, so the
# jobs run in parallel. The process handles are collected in `procs` so a later
# cell can wait for them. (Sequential alternative: call
# run_dqn_model(name, steps) directly instead of creating a Process.)
# NOTE(review): run_dqn_model uses its model_type argument only to build the
# output directory, checkpoint prefix and file name, and stable_baselines3_dqn
# always constructs DQN('MlpPolicy', ...) with the same seed, so these jobs
# appear to train the same configuration. TODO confirm this is intended.
for name in model_names:
    proc = Process(target=run_dqn_model, args=(name, steps))
    procs.append(proc)
    proc.start()

The environment seed is [42]
The environment seed is [42]
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.3      |
|    ep_rew_mean      | -0.787   |
|    exploration_rate | 0.954    |
|    success_rate     | 0.14     |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 2346     |
|    time_elapsed     | 205      |
|    total_timesteps  | 481446   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0827   |
|    n_updates        | 107861   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.3      |
|    ep_rew_mean      | -0.787   |
|    exploration_rate | 0.954    |
|    success_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.89     |
|    ep_rew_mean      | -0.93    |
|    exploration_rate | 0.597    |
|    success_rate     | 0.08     |
| time/               |          |
|    episodes         | 800000   |
|    fps              | 2071     |
|    time_elapsed     | 2049     |
|    total_timesteps  | 4244684  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 51       |
|    n_updates        | 1048670  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.65     |
|    ep_rew_mean      | -0.749   |
|    exploration_rate | 0.539    |
|    success_rate     | 0.16     |
| time/               |          |
|    episodes         | 900000   |
|    fps              | 2029     |
|    time_elapsed     | 2390     |
|    total_timesteps  | 4850457  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.73     |
|    ep_rew_mean      | -0.962   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.06     |
| time/               |          |
|    episodes         | 1600000  |
|    fps              | 1861     |
|    time_elapsed     | 5386     |
|    total_timesteps  | 10025133 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 7.23     |
|    n_updates        | 2493783  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.73     |
|    ep_rew_mean      | -0.962   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.06     |
| time/               |          |
|    episodes         | 1600000  |
|    fps              | 1862     |
|    time_elapsed     | 5384     |
|    total_timesteps  | 10025133 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.98     |
|    ep_rew_mean      | -0.656   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 2300000  |
|    fps              | 1864     |
|    time_elapsed     | 8293     |
|    total_timesteps  | 15461668 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.65     |
|    n_updates        | 3852916  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.31     |
|    ep_rew_mean      | -0.382   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.35     |
| time/               |          |
|    episodes         | 2400000  |
|    fps              | 1866     |
|    time_elapsed     | 8624     |
|    total_timesteps  | 16099417 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.02     |
|    ep_rew_mean      | -0.438   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.31     |
| time/               |          |
|    episodes         | 3100000  |
|    fps              | 1851     |
|    time_elapsed     | 10945    |
|    total_timesteps  | 20265093 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.564    |
|    n_updates        | 5053773  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.02     |
|    ep_rew_mean      | -0.438   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.31     |
| time/               |          |
|    episodes         | 3100000  |
|    fps              | 1850     |
|    time_elapsed     | 10950    |
|    total_timesteps  | 20265093 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.01     |
|    ep_rew_mean      | -0.514   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.28     |
| time/               |          |
|    episodes         | 3800000  |
|    fps              | 1826     |
|    time_elapsed     | 13372    |
|    total_timesteps  | 24428953 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.517    |
|    n_updates        | 6094738  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.82     |
|    ep_rew_mean      | -0.501   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.28     |
| time/               |          |
|    episodes         | 3900000  |
|    fps              | 1825     |
|    time_elapsed     | 13710    |
|    total_timesteps  | 25030566 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.8      |
|    ep_rew_mean      | -0.541   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.26     |
| time/               |          |
|    episodes         | 4600000  |
|    fps              | 1809     |
|    time_elapsed     | 16294    |
|    total_timesteps  | 29476836 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.719    |
|    n_updates        | 7356708  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.8      |
|    ep_rew_mean      | -0.541   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.26     |
| time/               |          |
|    episodes         | 4600000  |
|    fps              | 1807     |
|    time_elapsed     | 16312    |
|    total_timesteps  | 29476836 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.71     |
|    ep_rew_mean      | -0.615   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 5300000  |
|    fps              | 1794     |
|    time_elapsed     | 19076    |
|    total_timesteps  | 34240810 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.06     |
|    n_updates        | 8547702  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.36     |
|    ep_rew_mean      | -0.446   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.31     |
| time/               |          |
|    episodes         | 5400000  |
|    fps              | 1796     |
|    time_elapsed     | 19454    |
|    total_timesteps  | 34943500 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.57     |
|    ep_rew_mean      | -0.897   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.09     |
| time/               |          |
|    episodes         | 6100000  |
|    fps              | 1789     |
|    time_elapsed     | 22321    |
|    total_timesteps  | 39934228 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.739    |
|    n_updates        | 9971056  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.57     |
|    ep_rew_mean      | -0.897   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.09     |
| time/               |          |
|    episodes         | 6100000  |
|    fps              | 1786     |
|    time_elapsed     | 22353    |
|    total_timesteps  | 39934228 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.95     |
|    ep_rew_mean      | -0.707   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.19     |
| time/               |          |
|    episodes         | 6800000  |
|    fps              | 1778     |
|    time_elapsed     | 25599    |
|    total_timesteps  | 45534030 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.655    |
|    n_updates        | 11371007 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.94     |
|    ep_rew_mean      | -0.767   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.15     |
| time/               |          |
|    episodes         | 6900000  |
|    fps              | 1779     |
|    time_elapsed     | 26028    |
|    total_timesteps  | 46324986 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.24     |
|    ep_rew_mean      | -0.654   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.21     |
| time/               |          |
|    episodes         | 7600000  |
|    fps              | 1773     |
|    time_elapsed     | 29285    |
|    total_timesteps  | 51929845 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.346    |
|    n_updates        | 12969961 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.24     |
|    ep_rew_mean      | -0.654   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.21     |
| time/               |          |
|    episodes         | 7600000  |
|    fps              | 1771     |
|    time_elapsed     | 29319    |
|    total_timesteps  | 51929845 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.95     |
|    ep_rew_mean      | -0.627   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 8300000  |
|    fps              | 1766     |
|    time_elapsed     | 32610    |
|    total_timesteps  | 57590647 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.757    |
|    n_updates        | 14385161 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.77     |
|    ep_rew_mean      | -0.704   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.18     |
| time/               |          |
|    episodes         | 8400000  |
|    fps              | 1767     |
|    time_elapsed     | 33023    |
|    total_timesteps  | 58366306 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.4      |
|    ep_rew_mean      | -0.63    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 9100000  |
|    fps              | 1762     |
|    time_elapsed     | 36303    |
|    total_timesteps  | 64001122 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.365    |
|    n_updates        | 15987780 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.4      |
|    ep_rew_mean      | -0.63    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.22     |
| time/               |          |
|    episodes         | 9100000  |
|    fps              | 1760     |
|    time_elapsed     | 36345    |
|    total_timesteps  | 64001122 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.26     |
|    ep_rew_mean      | -0.675   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.2      |
| time/               |          |
|    episodes         | 9800000  |
|    fps              | 1756     |
|    time_elapsed     | 39653    |
|    total_timesteps  | 69637974 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.541    |
|    n_updates        | 17396993 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.06     |
|    ep_rew_mean      | -0.608   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.23     |
| time/               |          |
|    episodes         | 9900000  |
|    fps              | 1757     |
|    time_elapsed     | 40061    |
|    total_timesteps  | 70419803 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.37     |
|    ep_rew_mean      | -0.282   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.39     |
| time/               |          |
|    episodes         | 10600000 |
|    fps              | 1755     |
|    time_elapsed     | 43149    |
|    total_timesteps  | 75751518 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0823   |
|    n_updates        | 18925379 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.37     |
|    ep_rew_mean      | -0.282   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.39     |
| time/               |          |
|    episodes         | 10600000 |
|    fps              | 1753     |
|    time_elapsed     | 43201    |
|    total_timesteps  | 75751518 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.8      |
|    ep_rew_mean      | -0.437   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.34     |
| time/               |          |
|    episodes         | 11300000 |
|    fps              | 1753     |
|    time_elapsed     | 46033    |
|    total_timesteps  | 80736025 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.221    |
|    n_updates        | 20171506 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.4      |
|    ep_rew_mean      | -0.528   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.3      |
| time/               |          |
|    episodes         | 11400000 |
|    fps              | 1756     |
|    time_elapsed     | 46356    |
|    total_timesteps  | 81420055 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.24     |
|    ep_rew_mean      | 0.0542   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.57     |
| time/               |          |
|    episodes         | 12100000 |
|    fps              | 1759     |
|    time_elapsed     | 48793    |
|    total_timesteps  | 85843935 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.098    |
|    n_updates        | 21448483 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.24     |
|    ep_rew_mean      | 0.0542   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.57     |
| time/               |          |
|    episodes         | 12100000 |
|    fps              | 1757     |
|    time_elapsed     | 48853    |
|    total_timesteps  | 85843935 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.53     |
|    ep_rew_mean      | 0.192    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.62     |
| time/               |          |
|    episodes         | 12800000 |
|    fps              | 1759     |
|    time_elapsed     | 50958    |
|    total_timesteps  | 89643205 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.123    |
|    n_updates        | 22398301 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.59     |
|    ep_rew_mean      | 0.245    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.64     |
| time/               |          |
|    episodes         | 12900000 |
|    fps              | 1761     |
|    time_elapsed     | 51159    |
|    total_timesteps  | 90126113 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.87     |
|    ep_rew_mean      | 0.229    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.63     |
| time/               |          |
|    episodes         | 13600000 |
|    fps              | 1763     |
|    time_elapsed     | 53019    |
|    total_timesteps  | 93484659 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.038    |
|    n_updates        | 23358664 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.87     |
|    ep_rew_mean      | 0.229    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.63     |
| time/               |          |
|    episodes         | 13600000 |
|    fps              | 1760     |
|    time_elapsed     | 53088    |
|    total_timesteps  | 93484659 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.19     |
|    ep_rew_mean      | 0.208    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.62     |
| time/               |          |
|    episodes         | 14300000 |
|    fps              | 1762     |
|    time_elapsed     | 54874    |
|    total_timesteps  | 96700328 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0726   |
|    n_updates        | 24162581 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.51     |
|    ep_rew_mean      | 0.164    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.6      |
| time/               |          |
|    episodes         | 14400000 |
|    fps              | 1764     |
|    time_elapsed     | 55049    |
|    total_timesteps  | 97148361 |
| train/              |          |
|    learning_rate  

In [12]:
# Wait for every training process to finish, then report the outcome.
for proc in procs:
    proc.join()
# After join() each process has an exit code: 0 means it finished normally,
# anything else means it raised an uncaught exception or was killed by a signal.
# Checking it prevents reporting success for a job that actually crashed.
failed_procs = [p for p in procs if p.exitcode != 0]
if failed_procs:
    failed_summary = ', '.join(f'{p.name} (exit code {p.exitcode})' for p in failed_procs)
    print(f'{len(failed_procs)} of {len(procs)} jobs failed: {failed_summary}')
else:
    print('All jobs completed and terminated successfully')

All jobs completed and terminated successfully
