In [36]:
import pandas as pd
import numpy as np
import random
from sklearn.model_selection import train_test_split
import tensorflow
import os
import utils
from envs import SyntheticComplexHbEnv
from stable_baselines.common.env_checker import check_env
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import DQN
from stable_baselines import bench, logger
import matplotlib.pyplot as plt
%matplotlib inline

  "stable-baselines is in maintenance mode, please use [Stable-Baselines3 (SB3)](https://github.com/DLR-RM/stable-baselines3) for an up-to-date version. You can find a [migration guide](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html) in SB3 documentation."


In [3]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tensorflow.set_random_seed(SEED)
os.environ['PYTHONHASHSEED']=str(SEED)

#### Functions

In [16]:
def sample_train_set(x, y, sample_num):
    idx_list = random.sample(list(x.index), sample_num)
    print(idx_list)
    sampled_x = x.loc[idx_list]
    sampled_y = y.loc[idx_list]
    return np.array(sampled_x), np.array(sampled_y)

In [38]:
def stable_dqn(X_train, y_train, timesteps, model_file_name):
    training_env = SyntheticComplexHbEnv(X_train, y_train)
    env = bench.Monitor(training_env, logger.get_dir())
    model = DQN('MlpPolicy', training_env, verbose=1, seed=SEED, n_cpu_tf_sess=1)
    model.learn(total_timesteps=timesteps, log_interval=10000)
    model.save(f'models/{model_file_name}_{timesteps}.pkl')
    env.close()
    return model

In [37]:
def evaluate_dqn(dqn_model, X_test, y_test):
    test_df = pd.DataFrame()
    test_env = SyntheticComplexHbEnv(X_test, y_test, random=False)
    count=0

    try:
        while True:
            count+=1
            if count%5000==0:
                print(f'Count: {count}')
            obs, done = test_env.reset(), False
            while not done:
                action, _states = dqn_model.predict(obs, deterministic=True)
                obs, rew, done,info = test_env.step(action)
                if done == True:
                    test_df = test_df.append(info, ignore_index=True)
    except StopIteration:
        print('Testing done.....')
    return test_df

#### The Data

In [4]:
df = pd.read_csv('data/anemia_synth_dataset_hb_some_nans.csv') #my real dataset i think
df = df.fillna(0)
classes = list(df.label.unique())
nums = [i for i in range(len(classes))]
class_dict = dict(zip(classes, nums))
print(class_dict)
df['label'] = df['label'].replace(class_dict)
print(df.label.value_counts())
X = df.iloc[:, 0:-1]
y = df.iloc[:, -1]

full_X_train, X_test, full_y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=SEED)
X_test, y_test = np.array(X_test), np.array(y_test)
full_X_train.shape, X_test.shape, full_y_train.shape, y_test.shape

{'No anemia': 0,
 'Hemolytic anemia': 1,
 'Aplastic anemia': 2,
 'Iron deficiency anemia': 3,
 'Vitamin B12/Folate deficiency anemia': 4,
 'Anemia of chronic disease': 5}

In [None]:
train_sizes = [500, 1000, 3000, 5000, 10000]

In [None]:
for train_size in train_sizes[1]:
    X_train, y_train = sample_train_set(full_X_train, full_y_train, train_size)
    unique, counts = np.unique(y_train, return_counts=True)
    if len(unique) != 6:
        print(f'Skipping {train_size} because only {len(np.unique(y_train))} classes are in the sample')
    else:
        counts_dict = dict(zip(unique, counts))
        print(counts_dict)
        dqn_model = stable_dqn(X_train, y_train, int(2e3), f'train_sizes/{train_size}')
        test_df = evaluate_dqn(dqn_model, X_test, y_test)
        y_pred_df, success_df, success_rate = utils.get_success_rate(test_df)
        
        test_df.to_csv('test_dfs/train_sizes/test_df_2e6.csv', index=False)
        y_pred_df.to_csv('test_dfs/train_sizes/y_pred_df_2e6.csv', index=False)
        success_df.to_csv('test_dfs/train_sizes/success_df_2e6.csv', index=False)

1    14146
0    10000
2     9450
5     1869
4     1575
3     1343
Name: label, dtype: int64


((26868, 6), (11515, 6), (26868,), (11515,))

In [17]:
X_train, y_train = get_train_set(full_X_train, full_y_train, 10)
X_train

[28757, 23379, 36616, 36592, 31451, 30815, 25219, 32822, 37936, 7073]


Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv
28757,2.327392,159.123619,2.301563,0.0,0.0,86.538896
23379,10.018695,0.0,5.113819,0.0,0.0,83.41971
36616,2.650073,195.361559,4.727333,0.0,0.0,90.03499
36592,9.293322,86.760192,5.399623,0.0,0.0,84.800889
31451,2.026132,139.514259,0.343936,0.0,440.542479,81.700721
30815,14.958928,48.659838,0.0,0.0,0.0,89.468
25219,14.625554,47.883954,0.0,0.0,0.0,96.210813
32822,16.121179,193.918203,0.0,0.0,251.490067,82.681152
37936,8.213408,0.0,4.333178,0.0,357.561181,95.004433
7073,3.952061,0.0,4.363863,0.0,427.077906,98.404189


In [33]:
len(np.unique(y_train))

3

In [34]:
unique, counts = np.unique(y_train, return_counts=True)
unique, counts

(array([0, 1, 2], dtype=int64), array([3, 6, 1], dtype=int64))

In [35]:
counts_dict = dict(zip(unique, counts))
counts_dict

{0: 3, 1: 6, 2: 1}

In [31]:
df.loc[7073]

hemoglobin                 3.952061
ferritin                   0.000000
ret_count                  4.363863
segmented_neutrophils      0.000000
tibc                     427.077906
mcv                       98.404189
label                      1.000000
Name: 7073, dtype: float64