In [1]:
import os
import random
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.model_selection import train_test_split

sys.path.append('../..')
from modules.many_features import utils
#from scripts.many_features.env import SynthenticEnv
%matplotlib inline

  import pandas.util.testing as tm


In [2]:
SEED = 42

# Seed every random number generator this notebook relies on.
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here does not change
# hash randomization of the running kernel; it only affects child processes spawned later.
os.environ['PYTHONHASHSEED'] = str(SEED)
torch.manual_seed(SEED)
# Make torch raise instead of silently using nondeterministic kernels.
torch.use_deterministic_algorithms(True)

#### The Data

In [3]:
# Load the synthetic dataset, encode missing measurements as -1, and use `new_label`
# (renamed to `label`) as the target, discarding the original `label` column.
df = (
    pd.read_csv('../../data/anemia_synth_dataset_20_feats.csv')
    .fillna(-1)
    .drop(columns=['label'])
    .rename(columns={'new_label': 'label'})
)
df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,serum_iron,tibc,rbc,mcv,age,gender,indirect_bilirubin,transferrin,creatinine,cholestrol,copper,ethanol,folate,glucose,label
0,8.346669,34.608675,5.174318,1.0,178.823599,544.800971,4.925291,79.809413,19.453101,0,0.280924,165.29241,1.021167,143.166495,81.798487,63.004682,15.459905,42.754792,Iron deficiency anemia
1,6.846112,51.940156,1.932334,0.0,194.86117,461.075389,3.867615,91.181611,64.570588,0,1.035061,308.503918,0.463956,114.80362,81.114191,61.861251,27.576398,103.542063,Aplastic anemia
2,10.255567,27.599586,-1.0,-1.0,-1.0,214.483741,5.422049,75.707741,87.400889,1,1.671696,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,Iron deficiency anemia
3,15.223738,340.517667,-1.0,1.0,-1.0,424.972955,3.14241,93.390707,31.718546,1,2.121362,-1.0,1.617016,-1.0,-1.0,6.641431,-1.0,-1.0,No anemia
4,6.168014,380.230626,-1.0,1.0,197.169333,678.776865,4.317302,105.998157,64.037952,1,1.096339,204.885061,0.794662,-1.0,52.751477,50.990429,-1.0,55.228045,Vitamin B12/Folate deficiency anemia


In [4]:
# Class balance of the target: note how strongly it is skewed towards one class.
df['label'].value_counts()

Anemia of chronic disease               5055
Vitamin B12/Folate deficiency anemia    2226
Iron deficiency anemia                  1591
Aplastic anemia                         1420
Hemolytic anemia                        1397
No anemia                               1341
Unspecified anemia                      1000
Name: label, dtype: int64

In [5]:
class_dict = {'No anemia': 0, 'Vitamin B12/Folate deficiency anemia': 1, 'Unspecified anemia': 2, 
              'Anemia of chronic disease': 3, 'Iron deficiency anemia': 4, 'Hemolytic anemia': 5, 'Aplastic anemia': 6}
# Encode diagnoses as integer class ids. `replace` silently leaves any label missing from
# class_dict untouched (including -1 produced by the fillna above), so verify that every
# label was mapped before casting. Re-running the cell on already-encoded labels is a no-op.
encoded_labels = df['label'].replace(class_dict)
unmapped = encoded_labels[~encoded_labels.isin(list(class_dict.values()))].unique()
assert len(unmapped) == 0, f'labels missing from class_dict: {unmapped}'
df['label'] = encoded_labels.astype('int64')

# Select by column name so the split does not depend on 'label' being the last column.
X = df.drop(columns='label')
y = df['label']
# Stratify so the imbalanced class proportions are the same in train and test.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=SEED)
X_train, y_train = np.array(X_train), np.array(y_train)
X_test, y_test = np.array(X_test), np.array(y_test)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


((9821, 18), (4209, 18), (9821,), (4209,))

In [6]:
# Candidate actions: every diagnosis class (in class_dict order) followed by every feature column.
# NOTE(review): presumably this must match the action indexing used by the env inside utils — confirm.
action_list = [*class_dict.keys(), *(name for name in df.columns if name != 'label')]
action_list

['No anemia',
 'Vitamin B12/Folate deficiency anemia',
 'Unspecified anemia',
 'Anemia of chronic disease',
 'Iron deficiency anemia',
 'Hemolytic anemia',
 'Aplastic anemia',
 'hemoglobin',
 'ferritin',
 'ret_count',
 'segmented_neutrophils',
 'serum_iron',
 'tibc',
 'rbc',
 'mcv',
 'age',
 'gender',
 'indirect_bilirubin',
 'transferrin',
 'creatinine',
 'cholestrol',
 'copper',
 'ethanol',
 'folate',
 'glucose']

In [7]:
# Sanity check: one action per diagnosis class plus one per feature column.
assert len(action_list) == len(class_dict) + X_train.shape[1], 'action space does not match classes + features'
len(action_list)

25

#### Training the DQN agent

In [19]:
%%time
timesteps = int(10e6)
dqn_model = utils.stable_dqn(X_train, y_train, timesteps, True, f'../../models/many_features/stable_dqn_{timesteps}')
test_df = utils.evaluate_dqn(dqn_model, X_test, y_test)
test_df.head()

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.03     |
|    ep_rew_mean      | -5.62    |
|    exploration_rate | 0.614    |
|    success_rate     | 0.14     |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 1305     |
|    time_elapsed     | 311      |
|    total_timesteps  | 406424   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 5.97     |
|    n_updates        | 89105    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.38     |
|    ep_rew_mean      | -0.69    |
|    exploration_rate | 0.298    |
|    success_rate     | 0.38     |
| time/               |          |
|    episodes         | 200000   |
|    fps              | 1158     |
|    time_elapsed     | 637    

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.05     |
|    ep_rew_mean      | -0.35    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.35     |
| time/               |          |
|    episodes         | 1600000  |
|    fps              | 1038     |
|    time_elapsed     | 2163     |
|    total_timesteps  | 2247154  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.445    |
|    n_updates        | 549288   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.37    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.33     |
| time/               |          |
|    episodes         | 1700000  |
|    fps              | 1036     |
|    time_elapsed     | 2267     |
|    total_timesteps  | 2350907  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.49    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.27     |
| time/               |          |
|    episodes         | 3100000  |
|    fps              | 1025     |
|    time_elapsed     | 3708     |
|    total_timesteps  | 3803314  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.553    |
|    n_updates        | 938328   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.05     |
|    ep_rew_mean      | -0.31    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.37     |
| time/               |          |
|    episodes         | 3200000  |
|    fps              | 1024     |
|    time_elapsed     | 3812     |
|    total_timesteps  | 3907067  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.05     |
|    ep_rew_mean      | -0.23    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.41     |
| time/               |          |
|    episodes         | 4600000  |
|    fps              | 851      |
|    time_elapsed     | 6295     |
|    total_timesteps  | 5359913  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.387    |
|    n_updates        | 1327478  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.03     |
|    ep_rew_mean      | -0.29    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.37     |
| time/               |          |
|    episodes         | 4700000  |
|    fps              | 822      |
|    time_elapsed     | 6639     |
|    total_timesteps  | 5463694  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.04     |
|    ep_rew_mean      | -0.2     |
|    exploration_rate | 0.05     |
|    success_rate     | 0.42     |
| time/               |          |
|    episodes         | 6100000  |
|    fps              | 604      |
|    time_elapsed     | 11450    |
|    total_timesteps  | 6916442  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.363    |
|    n_updates        | 1716610  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.05     |
|    ep_rew_mean      | -0.33    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.36     |
| time/               |          |
|    episodes         | 6200000  |
|    fps              | 595      |
|    time_elapsed     | 11793    |
|    total_timesteps  | 7020265  |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.07     |
|    ep_rew_mean      | -0.51    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.28     |
| time/               |          |
|    episodes         | 7600000  |
|    fps              | 510      |
|    time_elapsed     | 16600    |
|    total_timesteps  | 8472630  |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.417    |
|    n_updates        | 2105657  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.02     |
|    ep_rew_mean      | -0.36    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.33     |
| time/               |          |
|    episodes         | 7700000  |
|    fps              | 506      |
|    time_elapsed     | 16940    |
|    total_timesteps  | 8576383  |
| train/              |          |
|    learning_rate  

Count: 4209
Testing done.....
Wall time: 10h 34min 47s


Unnamed: 0,episode_length,index,is_success,reward,terminated,trajectory,y_actual,y_pred
0,1.0,0.0,0.0,-1.0,0.0,[Anemia of chronic disease],0.0,3.0
1,1.0,1.0,0.0,-1.0,0.0,[Anemia of chronic disease],1.0,3.0
2,1.0,2.0,0.0,-1.0,0.0,[Anemia of chronic disease],2.0,3.0
3,1.0,3.0,1.0,1.0,0.0,[Anemia of chronic disease],3.0,3.0
4,1.0,4.0,0.0,-1.0,0.0,[Anemia of chronic disease],0.0,3.0


#### Testing

In [15]:
# Optional: reload a previously trained model instead of re-running the long training cell.
# NOTE(review): neither `DQN` nor `training_env` is defined anywhere in this notebook —
# presumably `from stable_baselines3 import DQN` plus the env built inside utils; confirm before uncommenting.
# The path mirrors the one passed to utils.stable_dqn (the saved file may carry a `.zip` suffix — TODO confirm).
# model = DQN.load(f'../../models/many_features/stable_dqn_{timesteps}', env=training_env)

In [24]:
# Share of successful test episodes plus the supporting frame. The rate appears to be a
# percentage (compare with the fractional accuracy returned by utils.test).
(success_rate, success_df) = utils.success_rate(test_df)
success_rate

36.04181515799477

In [25]:
# Mean episode length and mean return over test episodes. A mean length of 1.0 means every
# episode ended on its first action, i.e. the agent diagnosed without querying any feature.
(avg_length, avg_return) = utils.get_avg_length_reward(test_df)
(avg_length, avg_return)

(1.0, -0.27916369684010456)

In [26]:
# Classification metrics of the final diagnoses. In the recorded run the F1 value equals the
# macro-F1 of a constant predictor over 7 classes, and ROC-AUC is 0.5 (no discrimination).
y_true_test, y_pred_test = test_df['y_actual'], test_df['y_pred']
acc, f1, roc_auc = utils.test(y_true_test, y_pred_test)
acc, f1, roc_auc

(0.3604181515799477, 0.07569482560750461, 0.5)

In [27]:
# Distinct classes the trained policy predicted on the test set.
predicted_classes = test_df.y_pred.unique()
# If only one class is ever predicted, the policy collapsed to a constant answer: accuracy then
# just equals that class's share of the test set and ROC-AUC is 0.5, so flag it explicitly.
if len(predicted_classes) == 1:
    warnings.warn(
        f'DQN policy predicts the same class for every test sample: {predicted_classes}; '
        'the metrics above describe a constant predictor.'
    )
predicted_classes

array([3.])

#### Saving files

In [14]:
# Persist evaluation results. Off by default so "Restart & Run All" has no file-system side
# effects; set SAVE_RESULTS = True to write the CSVs (named after the training run).
SAVE_RESULTS = False
if SAVE_RESULTS:
    os.makedirs('test_dfs', exist_ok=True)
    test_df.to_csv(f'test_dfs/many_features_{timesteps}_test_df.csv')
    success_df.to_csv(f'test_dfs/many_features_{timesteps}_success_df.csv')