In [41]:
import random
import d3rlpy
import numpy as np
import pandas as pd
import tensorflow
import os
from d3rlpy.algos import DQN
from d3rlpy.dataset import MDPDataset
from sklearn.model_selection import train_test_split
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment
from envs import SyntheticSimpleEnv, SyntheticComplexEnv

In [42]:
# Single seed shared by every source of randomness used in this notebook.
SEED = 42

# PYTHONHASHSEED is read once at interpreter start-up, so setting it here does
# not change hashing in the running kernel; it is only inherited by child
# processes spawned afterwards.
os.environ['PYTHONHASHSEED'] = str(SEED)

random.seed(SEED)
np.random.seed(SEED)

# TensorFlow 2.x exposes the global seed as tf.random.set_seed; TensorFlow 1.x
# only has tf.set_random_seed. Pick whichever the installed version provides.
_tf_set_seed = getattr(getattr(tensorflow, 'random', None), 'set_seed', None)
if _tf_set_seed is None:
    _tf_set_seed = tensorflow.set_random_seed
_tf_set_seed(SEED)

# The DQN below is trained by d3rlpy, so seed it through d3rlpy's own helper
# as well; the calls above do not go through d3rlpy.
d3rlpy.seed(SEED)

In [43]:
# Simple synthetic dataset: every column but the last is a feature, the last
# column is the class label ('A'/'B'/'C'), which is encoded as 0/1/2.
class_dict_simple = {'A': 0, 'B': 1, 'C': 2}
df_simple = pd.read_csv('data/dataset_10000.csv')
df_simple['label'] = df_simple['label'].replace(class_dict_simple)

# Split into a feature matrix and a label vector as NumPy arrays.
X_simple = np.array(df_simple.iloc[:, 0:-1])
y_simple = np.array(df_simple.iloc[:, -1])

In [44]:
# Complex synthetic dataset: missing values are replaced with 0 before use.
df_complex = pd.read_csv('data/anemia_synth_dataset.csv').fillna(0)

# Encode each distinct label as an integer, numbered in order of first
# appearance in the file.
classes = list(df_complex['label'].unique())
nums = list(range(len(classes)))
class_dict_complex = {name: code for name, code in zip(classes, nums)}
df_complex['label'] = df_complex['label'].replace(class_dict_complex)

# Every column but the last is a feature; the last column is the label.
X_complex = np.array(df_complex.iloc[:, 0:-1])
y_complex = np.array(df_complex.iloc[:, -1])

In [45]:
# Build one environment per dataset from its feature matrix and label vector.
# simple_env is handed to train_dqn below for the 'environment' scorer.
simple_env = SyntheticSimpleEnv(X_simple, y_simple)
complex_env = SyntheticComplexEnv(X_complex, y_complex)

#### Random dataset

In [20]:
# Load the saved offline dataset from disk so it can be inspected.
dataset = MDPDataset.load('data/random_simple_dataset.h5')

In [21]:
# Number of episodes contained in the loaded dataset.
len(dataset.episodes)

50242

In [49]:
def train_dqn(dataset, env, model_name=None, n_epochs=10, test_size=0.3):
    """Train a DQN agent offline on a saved MDPDataset and save the result.

    Parameters
    ----------
    dataset : str
        Path to a dataset file readable by ``MDPDataset.load``.
    env
        Environment handed to ``evaluate_on_environment``; the resulting
        scorer is reported as ``'environment'`` after every epoch.
    model_name : str, optional
        Base name of the saved model file (``f'{model_name}.pt'``). When
        omitted, it is derived from the dataset file name, e.g.
        ``'data/random_simple_dataset.h5'`` -> ``'random_simple_dataset_dqn'``.
    n_epochs : int, optional
        Number of training epochs (default 10).
    test_size : float, optional
        Fraction of episodes held out for the evaluation scorers (default 0.3).

    Returns
    -------
    DQN
        The fitted model.
    """
    # Keep the path and the loaded dataset under separate names.
    mdp_dataset = MDPDataset.load(dataset)

    if model_name is None:
        dataset_stem = os.path.splitext(os.path.basename(dataset))[0]
        model_name = f'{dataset_stem}_dqn'

    train_episodes, test_episodes = train_test_split(
        mdp_dataset, test_size=test_size, random_state=SEED
    )

    dqn = DQN(use_gpu=False)
    dqn.build_with_dataset(mdp_dataset)

    # The scorers are evaluated by fit() after each epoch. Scoring the
    # untrained model up front only produced values that were thrown away
    # (and cost extra environment rollouts), so that step is gone.
    dqn.fit(
        train_episodes,
        eval_episodes=test_episodes,
        n_epochs=n_epochs,
        scorers={
            'td_error': td_error_scorer,
            'value_scale': average_value_estimation_scorer,
            'environment': evaluate_on_environment(env),
        },
    )

    dqn.save_model(f'{model_name}.pt')
    return dqn

In [50]:
# train_dqn takes the model name as its third argument; the model is saved to
# 'simple_random_dqn.pt'.
simple_random_dqn = train_dqn('data/random_simple_dataset.h5', simple_env, 'simple_random_dqn')

2022-07-29 17:02.24 [debug    ] RoundIterator is selected.
2022-07-29 17:02.24 [info     ] Directory is created at d3rlpy_logs\DQN_20220729170224
2022-07-29 17:02.24 [info     ] Parameters are saved to d3rlpy_logs\DQN_20220729170224\params.json params={'action_scaler': None, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 8000, 'use_gpu': None, 'algorithm': 'DQN', 'observation_shape': (3,), 'action_size': 6}


Epoch 1/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:03.08 [info     ] DQN_20220729170224: epoch=1 step=2190 epoch=1 metrics={'time_sample_batch': 0.00012073451525544467, 'time_algorithm_update': 0.003488982975755108, 'loss': 0.33465556674488056, 'time_step': 0.0037048583705675655, 'td_error': 1.5565589000518754, 'value_scale': 2.2305529996314526, 'environment': -0.2} step=2190
2022-07-29 17:03.08 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_2190.pt


Epoch 2/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:03.57 [info     ] DQN_20220729170224: epoch=2 step=4380 epoch=2 metrics={'time_sample_batch': 0.0001150887850757059, 'time_algorithm_update': 0.0035611700249589197, 'loss': 0.2667383168692186, 'time_step': 0.0037607726441126438, 'td_error': 1.5310171544096254, 'value_scale': 2.2278952830401986, 'environment': -1.0} step=4380
2022-07-29 17:03.57 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_4380.pt


Epoch 3/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:04.44 [info     ] DQN_20220729170224: epoch=3 step=6570 epoch=3 metrics={'time_sample_batch': 9.680480173189346e-05, 'time_algorithm_update': 0.0035394613056966705, 'loss': 0.25470457389678586, 'time_step': 0.003724697300288231, 'td_error': 1.444722966252833, 'value_scale': 2.2445662006178644, 'environment': 0.0} step=6570
2022-07-29 17:04.44 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_6570.pt


Epoch 4/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:05.32 [info     ] DQN_20220729170224: epoch=4 step=8760 epoch=4 metrics={'time_sample_batch': 0.00011194721204505119, 'time_algorithm_update': 0.0035849449297064516, 'loss': 0.25112347174520905, 'time_step': 0.0037773973865596125, 'td_error': 0.6159708031527409, 'value_scale': 2.6878994670018814, 'environment': 2.4} step=8760
2022-07-29 17:05.32 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_8760.pt


Epoch 5/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:06.20 [info     ] DQN_20220729170224: epoch=5 step=10950 epoch=5 metrics={'time_sample_batch': 9.741412994524116e-05, 'time_algorithm_update': 0.0036812642937925854, 'loss': 0.2293070062960936, 'time_step': 0.003878346970092216, 'td_error': 0.7284105694805405, 'value_scale': 2.808781335344405, 'environment': 0.4} step=10950
2022-07-29 17:06.20 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_10950.pt


Epoch 6/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:07.07 [info     ] DQN_20220729170224: epoch=6 step=13140 epoch=6 metrics={'time_sample_batch': 0.00010601516183652834, 'time_algorithm_update': 0.00364995198707058, 'loss': 0.22698394043124429, 'time_step': 0.0038447806824287866, 'td_error': 0.6653001610627708, 'value_scale': 2.7427495893282514, 'environment': 1.0} step=13140
2022-07-29 17:07.07 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_13140.pt


Epoch 7/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:07.54 [info     ] DQN_20220729170224: epoch=7 step=15330 epoch=7 metrics={'time_sample_batch': 0.00010072357578364681, 'time_algorithm_update': 0.003601349543218743, 'loss': 0.22453893180511314, 'time_step': 0.003787745297227276, 'td_error': 0.8311620545555137, 'value_scale': 2.8374592755538486, 'environment': 1.1} step=15330
2022-07-29 17:07.54 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_15330.pt


Epoch 8/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:08.40 [info     ] DQN_20220729170224: epoch=8 step=17520 epoch=8 metrics={'time_sample_batch': 9.891714679596086e-05, 'time_algorithm_update': 0.0036265472298887768, 'loss': 0.21909881134657827, 'time_step': 0.0038120181593176436, 'td_error': 0.5217435184555583, 'value_scale': 3.01786881357415, 'environment': 2.0} step=17520
2022-07-29 17:08.40 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_17520.pt


Epoch 9/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:09.28 [info     ] DQN_20220729170224: epoch=9 step=19710 epoch=9 metrics={'time_sample_batch': 9.191852726348459e-05, 'time_algorithm_update': 0.0036418330179501888, 'loss': 0.21166769268169794, 'time_step': 0.0038241975383671452, 'td_error': 0.5121731461590434, 'value_scale': 3.075703610174872, 'environment': 2.8} step=19710
2022-07-29 17:09.28 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_19710.pt


Epoch 10/10:   0%|          | 0/2190 [00:00<?, ?it/s]

2022-07-29 17:10.16 [info     ] DQN_20220729170224: epoch=10 step=21900 epoch=10 metrics={'time_sample_batch': 0.00011400425270812152, 'time_algorithm_update': 0.003809371277621892, 'loss': 0.2101646257623962, 'time_step': 0.00401546606734463, 'td_error': 0.4999135833799437, 'value_scale': 3.06569347494352, 'environment': 1.1} step=21900
2022-07-29 17:10.16 [info     ] Model parameters are saved to d3rlpy_logs\DQN_20220729170224\model_21900.pt


In [None]:
def synthetic_dqn_eval(dqn_model, env=None):
    """Roll a trained model through an environment and collect final-step info.

    Episodes are played one after another until the environment raises
    ``StopIteration`` (presumably once its samples are exhausted -- confirm
    against the env implementation). For every finished episode the ``info``
    dict returned by the terminal ``env.step`` call becomes one row of the
    result.

    Parameters
    ----------
    dqn_model
        Object with a ``predict(batch)`` method returning one action per
        observation in ``batch``.
    env : optional
        Environment exposing ``reset()`` and ``step(action)``. When omitted,
        ``SyntheticSimpleEnv(X_test, y_test, random=False)`` is built from the
        notebook-level ``X_test`` / ``y_test``.

    Returns
    -------
    pandas.DataFrame
        One row per completed episode; empty if no episode completed.
    """
    if env is None:
        env = SyntheticSimpleEnv(X_test, y_test, random=False)

    # Collect plain records and build the frame once at the end:
    # DataFrame.append was removed in pandas 2.0 and is quadratic in a loop.
    records = []
    count = 0

    try:
        while True:
            count += 1
            if count % 5000 == 0:
                print(f'Count: {count}')
            obs, done = env.reset(), False
            while not done:
                action = dqn_model.predict([obs])[0]
                obs, _, done, info = env.step(action)
            # Copy so later mutation of the env's dict cannot alter stored rows.
            records.append(dict(info))
    except StopIteration:
        print('Testing done.....')

    return pd.DataFrame(records)

# Evaluate the model trained above; no variable named `dqn_model` is defined
# in this notebook.
test_df = synthetic_dqn_eval(simple_random_dqn)