In [1]:
import pandas as pd
import numpy as np
import random
import os
import torch
import sys
sys.path.append('../')
from lillian_modules import utils, simple_constants
from lillian_modules.simple_env import SimpleEnv
from stable_baselines3.dqn import DQN
# from modules.env import LupusEnv
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Reproducibility: take the project-wide seed and apply it to every RNG this
# notebook relies on (Python's `random`, NumPy's legacy global RNG, PyTorch).
SEED = simple_constants.SEED

# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here is
# inherited by child processes but does not change hashing in the running kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Make PyTorch raise instead of silently using a non-deterministic kernel.
torch.use_deterministic_algorithms(True)

In [3]:
# Project-level label mapping; used below with `Series.replace` to encode the
# `label` column. Presumably maps diagnosis names (e.g. 'No lupus', 'Lupus') to
# integer class ids — TODO confirm against `simple_constants.CLASS_DICT`.
class_dict = simple_constants.CLASS_DICT

#### The data

In [4]:
# Load the `feats_8` dataset (the path is relative to the notebook's working
# directory) and preview the first rows.
df = pd.read_csv(
    '../../RL-Agent-Diagnosis/lupus/data/very_simple_datasets/feats_8.csv'
)
df.head(n=5)

Unnamed: 0,ana,anti_dsdna_antibody,joint_involvement,proteinuria,pericardial_effusion,non_scarring_alopecia,leukopenia,delirium,label
0,0,0,1,0,1,1,1,0,No lupus
1,1,0,0,0,0,0,0,0,No lupus
2,0,0,1,0,0,0,0,1,No lupus
3,0,1,0,1,0,0,0,0,No lupus
4,1,0,1,0,1,1,1,0,Lupus


In [5]:
# Encode the diagnosis labels as integer class ids.
# `replace` (rather than `map`) is kept so that values that are not keys of
# `class_dict` pass through unchanged — presumably making a re-run on
# already-encoded labels a no-op (assumes the keys are the label strings).
# The explicit cast avoids depending on pandas' deprecated implicit
# object -> int downcast in `replace`, and raises immediately if any
# non-numeric label was left unmapped.
df['label'] = df['label'].replace(class_dict).astype('int64')

# Features are every column except the last; the target is the last column.
X = df.iloc[:, :-1]
y = df.iloc[:, -1]

# Hold out 20% of the data for testing, then 10% of the remainder for
# validation. Both splits are stratified on the label and seeded.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=0.20, stratify=y, random_state=SEED
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.10, stratify=y_train_val, random_state=SEED
)

# Keep a labelled DataFrame copy of each split (features + label, fresh index).
training_df = pd.concat([X_train, y_train], axis=1).reset_index(drop=True)
validation_df = pd.concat([X_val, y_val], axis=1).reset_index(drop=True)
testing_df = pd.concat([X_test, y_test], axis=1).reset_index(drop=True)

# Sanity check: the three splits partition the full dataset.
assert len(training_df) + len(validation_df) + len(testing_df) == len(df)

# Convert the splits to plain NumPy arrays.
X_train, y_train = np.array(X_train), np.array(y_train)
X_val, y_val = np.array(X_val), np.array(y_val)
X_test, y_test = np.array(X_test), np.array(y_test)

X_train.shape, X_val.shape, X_test.shape, y_train.shape, y_val.shape, y_test.shape

((50400, 8), (5600, 8), (14000, 8), (50400,), (5600,), (14000,))

In [6]:
# Class balance of the training split.
training_df['label'].value_counts()

0    30854
1    19546
Name: label, dtype: int64

#### Training the model

In [7]:
def stable_dqn3(X_train, y_train, timesteps, validation_times=10, save=False, filename=None,
                validation_data=None):
    """Train a DQN agent on a ``SimpleEnv`` built from the training data.

    Parameters
    ----------
    X_train, y_train
        Training features and labels, passed straight to ``SimpleEnv``.
    timesteps : int
        Forwarded to ``model.learn`` as ``total_timesteps``.
    validation_times : int, default 10
        Forwarded to ``model.learn``; presumably the number of validation
        passes spread over training (project-specific DQN extension —
        confirm against the project's ``stable_baselines3`` fork).
    save : bool, default False
        When true, the trained model is written with ``model.save(filename)``.
    filename : optional
        Destination passed to ``model.save``. Required when ``save`` is true.
    validation_data : tuple, optional
        ``(X_val, y_val)`` handed to the DQN constructor. When omitted, the
        notebook-level ``X_val`` / ``y_val`` variables are used.

    Returns
    -------
    The trained ``DQN`` model.

    Raises
    ------
    ValueError
        If ``save`` is true but no ``filename`` was given. Raised before any
        training starts so a long run is not wasted on a failing save.
    """
    if save and filename is None:
        raise ValueError('filename must be provided when save=True')

    # Fall back to the notebook-level validation split when none is passed in.
    if validation_data is None:
        validation_data = (X_val, y_val)
    val_features, val_labels = validation_data

    training_env = SimpleEnv(X_train, y_train)
    try:
        print(f'ACTION SPACE: {training_env.actions}')
        # X_val / y_val / to_save_folder are extra keyword arguments accepted by
        # the DQN class imported in this notebook. The output folder name
        # mirrors the discount factor (gamma=0.90) used for this experiment.
        model = DQN(policy='MlpPolicy', env=training_env, verbose=1, seed=SEED,
                    X_val=val_features, y_val=val_labels, gamma=0.90,
                    to_save_folder='../models/very_simple_models/best_models/negative_reward_0.90_gamma/')
        model.learn(total_timesteps=timesteps, validation_times=validation_times, log_interval=100000)
        if save:
            model.save(filename)
    finally:
        # Release the environment even if training or saving fails.
        training_env.close()
    return model

In [8]:
# Train on the 8-feature dataset for 5M timesteps with 100 validation passes.
# The model is not saved by this call (`save` keeps its default of False).
ft_num = 8
steps = 5_000_000
model = stable_dqn3(
    X_train,
    y_train,
    timesteps=steps,
    validation_times=100,
)

ACTION SPACE: ['No lupus', 'Lupus', 'Inconclusive diagnosis', 'ana', 'anti_dsdna_antibody', 'joint_involvement', 'proteinuria', 'pericardial_effusion', 'non_scarring_alopecia', 'leukopenia', 'delirium']
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
setting up model
Validating at 50000/5000000
Validating at 100000/5000000
Validating at 150000/5000000
Validating at 200000/5000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.8      |
|    ep_rew_mean      | 1.91     |
|    exploration_rate | 0.601    |
|    success_rate     | 0.42     |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 1977     |
|    time_elapsed     | 106      |
|    total_timesteps  | 209978   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 3.89     |
|    n_updates        | 39994    |
----------------------------------
Validating 

Validating at 1550000/5000000
Validating at 1600000/5000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.02     |
|    ep_rew_mean      | 6.62     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.61     |
| time/               |          |
|    episodes         | 1400000  |
|    fps              | 1279     |
|    time_elapsed     | 1263     |
|    total_timesteps  | 1616188  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 7.97     |
|    n_updates        | 391546   |
----------------------------------
Validating at 1650000/5000000
Validating at 1700000/5000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1        |
|    ep_rew_mean      | 5.44     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.56     |
| time/               |          |
|    episodes         | 1500000  |
|    fps              | 1285     |
|    

Validating at 3000000/5000000
Validating at 3050000/5000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.06     |
|    ep_rew_mean      | 4.42     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.52     |
| time/               |          |
|    episodes         | 2800000  |
|    fps              | 1213     |
|    time_elapsed     | 2528     |
|    total_timesteps  | 3067853  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 10.1     |
|    n_updates        | 754463   |
----------------------------------
Validating at 3100000/5000000
Validating at 3150000/5000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.01     |
|    ep_rew_mean      | 6.63     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.61     |
| time/               |          |
|    episodes         | 2900000  |
|    fps              | 1208     |
|    

Validating at 4450000/5000000
Validating at 4500000/5000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.01     |
|    ep_rew_mean      | 7.35     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.64     |
| time/               |          |
|    episodes         | 4200000  |
|    fps              | 1246     |
|    time_elapsed     | 3627     |
|    total_timesteps  | 4520487  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 8.7      |
|    n_updates        | 1117621  |
----------------------------------
Validating at 4550000/5000000
Validating at 4600000/5000000
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.05     |
|    ep_rew_mean      | 7.07     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.63     |
| time/               |          |
|    episodes         | 4300000  |
|    fps              | 1236     |
|    