In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
import random
import glob
import os
from multiprocessing import Process
from stable_baselines import DQN
import sys
sys.path.append('..')
from modules import utils, constants
from modules.masked_env import LupusEnv

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



  "stable-baselines is in maintenance mode, please use [Stable-Baselines3 (SB3)](https://github.com/DLR-RM/stable-baselines3) for an up-to-date version. You can find a [migration guide](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html) in SB3 documentation."





In [2]:
# Seed every random number source used by the notebook so runs are repeatable.
SEED = constants.SEED
random.seed(SEED)
np.random.seed(SEED)
# NOTE: PYTHONHASHSEED is read when the interpreter starts, so setting it here
# does not change hashing in this kernel; it is only inherited by child
# processes started afterwards.
os.environ['PYTHONHASHSEED']=str(SEED)
# Graph-level TensorFlow seed. The original cell also called the
# `tf.set_random_seed` alias with the same value; one call is sufficient.
tf.compat.v1.set_random_seed(SEED)
SEED

In [3]:
def create_env(X, y, random=True):
    """Construct a ``LupusEnv`` from features ``X`` and labels ``y``.

    Args:
        X: Feature data forwarded to ``LupusEnv``.
        y: Label data forwarded to ``LupusEnv``.
        random: Forwarded unchanged as the third positional argument of
            ``LupusEnv`` (presumably toggles random sample selection --
            confirm against ``modules.masked_env``). Note that inside this
            function the name shadows the stdlib ``random`` module.

    Returns:
        The newly created environment.
    """
    lupus_env = LupusEnv(X, y, random)
    # Temporary debugging aid (flagged for removal by the author). It calls
    # ``seed()`` on the environment, so dropping it may not be a pure no-op.
    reported_seed = lupus_env.seed()
    print(f'The environment seed is {reported_seed}')
    return lupus_env

In [13]:
def stable_dueling_dqn(X_train, y_train, timesteps, per=False):
    """Train a stable-baselines ``DQN`` agent on a freshly built environment.

    Args:
        X_train: Feature data passed to ``create_env``.
        y_train: Label data passed to ``create_env``.
        timesteps: Total number of environment steps given to ``model.learn``.
        per: If true, enables prioritized experience replay
            (``prioritized_replay``) in the DQN constructor.

    Returns:
        The ``DQN`` model after ``learn`` has returned. The training
        environment has already been closed at that point.

    Raises:
        Any exception raised while building or training the model is
        propagated after the environment has been closed.
    """
    training_env = create_env(X_train, y_train)
    try:
        # double_q is disabled explicitly; the "dueling" part of the name
        # presumably relies on the policy's default architecture -- TODO confirm.
        # learning_starts=50000 means shorter runs only collect transitions.
        model = DQN('MlpPolicy', training_env, verbose=1, seed=constants.SEED, learning_rate=0.0001, buffer_size=1000000, learning_starts=50000, 
                    train_freq=4, target_network_update_freq=100, exploration_final_eps=0.05, n_cpu_tf_sess=1, double_q=False, prioritized_replay=per)

        model.learn(total_timesteps=timesteps, log_interval=100000)
    finally:
        # Always release the environment, even when construction or training
        # fails, so its resources are not leaked.
        training_env.close()
    return model

In [14]:
def run_dqn_model(X_train, y_train, steps):
    """Train a DQN agent for ``steps`` timesteps and return the model.

    Thin wrapper around ``stable_dueling_dqn`` that leaves prioritized replay
    at its default setting.
    """
    return stable_dueling_dqn(X_train, y_train, timesteps=steps)

In [15]:
# Read the training set and encode every missing value as -1 in one pass.
df = pd.read_csv('../new_data/train_set_missingness_0.3.csv').fillna(-1)
df.head()

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,proteinuria,biopsy_proven_lupus_nephritis,anti_cardioliphin_antibodies,anti_β2gp1_antibodies,lupus_anti_coagulant,low_c3,low_c4,anti_dsdna_antibody,anti_smith_antibody,label
0,1,1.0,-1.0,0.0,0.0,0.0,0.0,0.0,-1.0,0.0,...,-1.0,0.0,-1.0,0.0,0.0,0.0,1.0,0.0,0.0,1
1,1,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,-1.0,-1.0,-1.0,-1.0,0.0,1.0,-1.0,1
2,0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,-1.0,0.0,0.0,0.0,-1.0,0.0,0.0,0
3,0,0.0,-1.0,-1.0,-1.0,0.0,-1.0,0.0,0.0,0.0,...,1.0,0.0,-1.0,0.0,-1.0,0.0,0.0,-1.0,0.0,0
4,1,0.0,0.0,0.0,-1.0,1.0,-1.0,-1.0,1.0,0.0,...,1.0,3.0,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,1


In [16]:
# Every column except the last is a feature; the last column is the label.
X = np.array(df.iloc[:, :-1])
Y = np.array(df.iloc[:, -1])
X.shape, Y.shape

((50400, 24), (50400,))

In [17]:
# Two all-zero float32 vectors with one slot per feature column.
# NOTE(review): presumably these mirror the two rows of the environment state
# shown in the training log -- confirm against LupusEnv.
feat_num = X.shape[1]
s = np.zeros(feat_num, dtype=np.float32)
bin_mask = np.zeros_like(s)
s

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [18]:
# Stack the two vectors row-wise into a single (2, feat_num) array.
c = np.stack((s, bin_mask), axis=0)
c.shape

(2, 24)

In [28]:
# Row index under manual inspection (the same index appears in the pasted
# step log further down); display its feature vector.
idx=1639
X[idx]

array([ 1., -1.,  0., -1.,  0.,  0., -1.,  0.,  1., -1., -1.,  0.,  0.,
        0.,  1.,  1.,  0., -1., -1.,  0.,  0.,  1.,  0.,  0.])

In [29]:
# Label of the row selected by `idx`.
Y[idx]

1

In [None]:
Step 3 of index 1639
old state: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
action: 19
new state: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

In [27]:
# NOTE(review): 100 timesteps is far below the learning_starts=50000 set in
# stable_dueling_dqn, so this looks like a smoke test of the environment
# rather than real training -- confirm. The returned model is not kept.
run_dqn_model(X, Y, 100)

The environment seed is [42]
RESETTING!!!
Step 1 of index 41905
old state: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
action: 0
new state: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
info: {'index': 41905, 'episode_length': 1, 'reward': 1, 'y_pred': 0, 'y_actual': 0, 'trajectory': ['No lupus'], 'terminated': False, 'is_success': True}
RESETTING!!!
Step 1 of index 7296
old state: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
action: 16
new state: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
info: {'index': 7296, 'episode_length': 1, 'reward': 0, 'y_pred': nan, 'y_actua

<stable_baselines.deepq.dqn.DQN at 0x7f91ce0b1d50>