In [1]:
# Imports for data handling, RNG seeding, and the project's helper modules.
import pandas as pd
import numpy as np
import random
import os
import torch
import sys
# Make the project's `modules` package importable: '../' is resolved relative to
# the kernel's current working directory, so this must run before the import below.
sys.path.append('../')
from modules import utils, constants
import warnings
# NOTE(review): this silences *every* warning (including deprecation warnings from
# third-party libraries); consider filtering only specific categories/modules.
warnings.filterwarnings('ignore')

In [2]:
SEED = 42

# Python, NumPy and PyTorch each keep an independent global RNG; seed all three.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE: PYTHONHASHSEED is only read at interpreter start-up, so this affects
# Python subprocesses launched later, not hash randomisation in this kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)

# Make PyTorch use deterministic kernels (and raise on ops that have none).
torch.use_deterministic_algorithms(True)

#### The data: load and inspect the training split

In [3]:
# Training split from the `missingness/0` dataset
# (an earlier run used '../data/25_jan/train_set_basic.csv').
train_df = pd.read_csv('../data/missingness/0/training_set.csv')
train_df.head(n=5)

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,joint_involvement,proteinuria,anti_cardioliphin_antibodies,anti_β2gp1_antibodies,lupus_anti_coagulant,low_c3,low_c4,anti_dsdna_antibody,anti_smith_antibody,label
0,0,1,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,1,0,1
3,1,0,0,0,0,0,0,1,1,0,...,0,1,0,0,0,0,0,1,1,1
4,1,0,0,1,1,1,0,0,0,0,...,0,1,0,0,0,1,0,0,1,1


In [4]:
# Frequency of each `cutaneous_lupus` level; this feature has more than two values.
train_df['cutaneous_lupus'].value_counts()

0    44200
3     3120
1     2246
2      834
Name: cutaneous_lupus, dtype: int64

In [5]:
# Rows with label == 1 but a negative ANA (ana == 0); an empty frame means there are none.
train_df.loc[(train_df['ana'] == 0) & (train_df['label'] == 1)]

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,joint_involvement,proteinuria,anti_cardioliphin_antibodies,anti_β2gp1_antibodies,lupus_anti_coagulant,low_c3,low_c4,anti_dsdna_antibody,anti_smith_antibody,label


In [6]:
# Per-column count of missing (NaN) values.
train_df.isnull().sum(axis=0)

ana                             0
fever                           0
leukopenia                      0
thrombocytopenia                0
auto_immune_hemolysis           0
delirium                        0
psychosis                       0
seizure                         0
non_scarring_alopecia           0
oral_ulcers                     0
cutaneous_lupus                 0
pleural_effusion                0
pericardial_effusion            0
acute_pericarditis              0
joint_involvement               0
proteinuria                     0
anti_cardioliphin_antibodies    0
anti_β2gp1_antibodies           0
lupus_anti_coagulant            0
low_c3                          0
low_c4                          0
anti_dsdna_antibody             0
anti_smith_antibody             0
label                           0
dtype: int64

In [7]:
# Number of training rows.
train_df.shape[0]

50400

In [8]:
# Split features and target by column *name* rather than position. `action_list`
# below also drops 'label' by name, so this keeps the feature order aligned with
# the action list even if 'label' is ever not the last column in the CSV.
feature_cols = [col for col in train_df.columns if col != 'label']
X_train = train_df[feature_cols].to_numpy()
y_train = train_df['label'].to_numpy()
X_train.shape, y_train.shape

((50400, 23), (50400,))

In [9]:
# Action space: the class names from CLASS_DICT first, followed by one entry per
# feature column (every column except 'label') — presumably the features the agent
# can choose to query; confirm against LupusEnv.
action_list = [*constants.CLASS_DICT, *(col for col in train_df.columns if col != 'label')]
action_list

['No lupus',
 'Lupus',
 'Inconclusive diagnosis',
 'ana',
 'fever',
 'leukopenia',
 'thrombocytopenia',
 'auto_immune_hemolysis',
 'delirium',
 'psychosis',
 'seizure',
 'non_scarring_alopecia',
 'oral_ulcers',
 'cutaneous_lupus',
 'pleural_effusion',
 'pericardial_effusion',
 'acute_pericarditis',
 'joint_involvement',
 'proteinuria',
 'anti_cardioliphin_antibodies',
 'anti_β2gp1_antibodies',
 'lupus_anti_coagulant',
 'low_c3',
 'low_c4',
 'anti_dsdna_antibody',
 'anti_smith_antibody']

In [10]:
# steps = 11279933
# dqn_model = utils.stable_dqn3(X_train, y_train, steps, True, f'../models/24_feb/dqn_basic_{steps}')

In [11]:
# Earlier DQN result: 0.92 at 11,279,933 timesteps (1,900,000 episodes), and again at 36,144,785 timesteps (7,800,000 episodes)

In [12]:
# for steps in [int(5e4), int(1e5), int(3e5), int(5e5), int(8e5), int(1e6), int(2e6)]:    
# # for steps in [int(1.1e5)]:  
#     dqn_model = utils.stable_dqn3(X_train, y_train, steps, True, f'../models/26_jan/dqn_basic_{steps}')

#### The model: train an A2C agent

In [13]:
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import CheckpointCallback
from modules.env import LupusEnv

In [14]:
# Train an A2C agent on LupusEnv and save it. At the ~2,100 fps seen in the
# logged run this takes roughly 9 hours on CPU; checkpoints are written periodically.
TOTAL_TIMESTEPS = 70_000_000
CHECKPOINT_FREQ = 100_000  # env steps between checkpoints (single env)
LOG_INTERVAL = 100_000     # counted in A2C iterations, not timesteps (see 'iterations' in the log)

# Build the env inside the factory. The original `lambda: training_env` closed over
# a name that is rebound to the DummyVecEnv on the same line (late-binding trap).
training_env = DummyVecEnv([lambda: LupusEnv(X_train, y_train)])
try:
    # NOTE(review): the agent is seeded with constants.SEED while the RNGs above use
    # the notebook-local SEED; confirm the two values agree.
    model = A2C('MlpPolicy', training_env, verbose=1, seed=constants.SEED)
    checkpoint_callback = CheckpointCallback(
        save_freq=CHECKPOINT_FREQ,
        save_path='../models/logs/a2c',
        name_prefix='a2c_basic',
    )
    # NOTE(review): in the logged run entropy_loss is ~-3.5e-06 and policy_loss is -0
    # from ~8.5M steps onward, i.e. the policy became (near-)deterministic early.
    # Consider a positive ent_coef or fewer timesteps before re-running.
    model.learn(total_timesteps=TOTAL_TIMESTEPS, log_interval=LOG_INTERVAL, callback=checkpoint_callback)

    # Save the trained A2C agent
    model.save('../models/22_mar/a2c_lupus_diagnosis')
finally:
    # Release env resources even if training raises or is interrupted.
    training_env.close()

Using cpu device
-------------------------------------
| time/                 |           |
|    fps                | 2138      |
|    iterations         | 100000    |
|    time_elapsed       | 233       |
|    total_timesteps    | 500000    |
| train/                |           |
|    entropy_loss       | -8.88e-05 |
|    explained_variance | -1.19e-07 |
|    learning_rate      | 0.0007    |
|    n_updates          | 99999     |
|    policy_loss        | -1.4e-06  |
|    value_loss         | 1.02      |
-------------------------------------
-------------------------------------
| time/                 |           |
|    fps                | 1909      |
|    iterations         | 200000    |
|    time_elapsed       | 523       |
|    total_timesteps    | 1000000   |
| train/                |           |
|    entropy_loss       | -2.85e-05 |
|    explained_variance | 5.96e-08  |
|    learning_rate      | 0.0007    |
|    n_updates          | 199999    |
|    policy_loss        | 3.43e-0

-------------------------------------
| time/                 |           |
|    fps                | 2005      |
|    iterations         | 1700000   |
|    time_elapsed       | 4238      |
|    total_timesteps    | 8500000   |
| train/                |           |
|    entropy_loss       | -9.89e-06 |
|    explained_variance | nan       |
|    learning_rate      | 0.0007    |
|    n_updates          | 1699999   |
|    policy_loss        | -0        |
|    value_loss         | 0.824     |
-------------------------------------
-------------------------------------
| time/                 |           |
|    fps                | 2014      |
|    iterations         | 1800000   |
|    time_elapsed       | 4466      |
|    total_timesteps    | 9000000   |
| train/                |           |
|    entropy_loss       | -9.85e-06 |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 1799999   |
|    policy_loss        | -0        |
|    value_l

-------------------------------------
| time/                 |           |
|    fps                | 2089      |
|    iterations         | 3300000   |
|    time_elapsed       | 7894      |
|    total_timesteps    | 16500000  |
| train/                |           |
|    entropy_loss       | -3.67e-06 |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 3299999   |
|    policy_loss        | -0        |
|    value_loss         | 0.969     |
-------------------------------------
-------------------------------------
| time/                 |           |
|    fps                | 2092      |
|    iterations         | 3400000   |
|    time_elapsed       | 8123      |
|    total_timesteps    | 17000000  |
| train/                |           |
|    entropy_loss       | -3.66e-06 |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 3399999   |
|    policy_loss        | -0        |
|    value_l

------------------------------------
| time/                 |          |
|    fps                | 2120     |
|    iterations         | 4900000  |
|    time_elapsed       | 11554    |
|    total_timesteps    | 24500000 |
| train/                |          |
|    entropy_loss       | -3.6e-06 |
|    explained_variance | 0        |
|    learning_rate      | 0.0007   |
|    n_updates          | 4899999  |
|    policy_loss        | -0       |
|    value_loss         | 1.03     |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 2121     |
|    iterations         | 5000000  |
|    time_elapsed       | 11782    |
|    total_timesteps    | 25000000 |
| train/                |          |
|    entropy_loss       | -3.6e-06 |
|    explained_variance | 0        |
|    learning_rate      | 0.0007   |
|    n_updates          | 4999999  |
|    policy_loss        | -0       |
|    value_loss         | 1.06     |
-

-------------------------------------
| time/                 |           |
|    fps                | 2136      |
|    iterations         | 6500000   |
|    time_elapsed       | 15211     |
|    total_timesteps    | 32500000  |
| train/                |           |
|    entropy_loss       | -3.54e-06 |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 6499999   |
|    policy_loss        | -0        |
|    value_loss         | 1.04      |
-------------------------------------
-------------------------------------
| time/                 |           |
|    fps                | 2137      |
|    iterations         | 6600000   |
|    time_elapsed       | 15440     |
|    total_timesteps    | 33000000  |
| train/                |           |
|    entropy_loss       | -3.54e-06 |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 6599999   |
|    policy_loss        | -0        |
|    value_l

------------------------------------
| time/                 |          |
|    fps                | 2146     |
|    iterations         | 8100000  |
|    time_elapsed       | 18869    |
|    total_timesteps    | 40500000 |
| train/                |          |
|    entropy_loss       | -3.5e-06 |
|    explained_variance | 0        |
|    learning_rate      | 0.0007   |
|    n_updates          | 8099999  |
|    policy_loss        | -0       |
|    value_loss         | 0.996    |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 2146     |
|    iterations         | 8200000  |
|    time_elapsed       | 19097    |
|    total_timesteps    | 41000000 |
| train/                |          |
|    entropy_loss       | -3.5e-06 |
|    explained_variance | nan      |
|    learning_rate      | 0.0007   |
|    n_updates          | 8199999  |
|    policy_loss        | -0       |
|    value_loss         | 0.863    |
-

-------------------------------------
| time/                 |           |
|    fps                | 2152      |
|    iterations         | 9700000   |
|    time_elapsed       | 22527     |
|    total_timesteps    | 48500000  |
| train/                |           |
|    entropy_loss       | -3.48e-06 |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 9699999   |
|    policy_loss        | -0        |
|    value_loss         | 1.02      |
-------------------------------------
-------------------------------------
| time/                 |           |
|    fps                | 2153      |
|    iterations         | 9800000   |
|    time_elapsed       | 22756     |
|    total_timesteps    | 49000000  |
| train/                |           |
|    entropy_loss       | -3.48e-06 |
|    explained_variance | 1.79e-07  |
|    learning_rate      | 0.0007    |
|    n_updates          | 9799999   |
|    policy_loss        | -0        |
|    value_l

-------------------------------------
| time/                 |           |
|    fps                | 2157      |
|    iterations         | 11300000  |
|    time_elapsed       | 26190     |
|    total_timesteps    | 56500000  |
| train/                |           |
|    entropy_loss       | -3.48e-06 |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 11299999  |
|    policy_loss        | -0        |
|    value_loss         | 0.977     |
-------------------------------------
-------------------------------------
| time/                 |           |
|    fps                | 2157      |
|    iterations         | 11400000  |
|    time_elapsed       | 26419     |
|    total_timesteps    | 57000000  |
| train/                |           |
|    entropy_loss       | -3.48e-06 |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 11399999  |
|    policy_loss        | -0        |
|    value_l

-------------------------------------
| time/                 |           |
|    fps                | 2160      |
|    iterations         | 12900000  |
|    time_elapsed       | 29856     |
|    total_timesteps    | 64500000  |
| train/                |           |
|    entropy_loss       | -3.49e-06 |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 12899999  |
|    policy_loss        | -0        |
|    value_loss         | 1.04      |
-------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 2160     |
|    iterations         | 13000000 |
|    time_elapsed       | 30085    |
|    total_timesteps    | 65000000 |
| train/                |          |
|    entropy_loss       | -3.5e-06 |
|    explained_variance | 0        |
|    learning_rate      | 0.0007   |
|    n_updates          | 12999999 |
|    policy_loss        | -0       |
|    value_loss         