In [1]:
import copy
import os
import random

import numpy as np
import pandas as pd
import tensorflow
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, label_binarize
from gym import Env
from gym.spaces import Discrete, Box
import d3rlpy
from d3rlpy.algos import DQN
from d3rlpy.online.buffers import ReplayBuffer
from d3rlpy.online.explorers import LinearDecayEpsilonGreedy

import helper
from envs import SyntheticSimpleEnv, SyntheticComplexEnv

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tensorflow.set_random_seed(SEED)
# The DQNs below are d3rlpy models (checkpoints are saved as .pt files), whose
# network initialisation and minibatch sampling are not covered by the calls
# above; d3rlpy.seed seeds Python's `random`, NumPy and PyTorch together.
d3rlpy.seed(SEED)
# NOTE: PYTHONHASHSEED is only read when an interpreter starts, so setting it
# here does not change hash randomisation inside this kernel; it only affects
# child processes started from it.
os.environ['PYTHONHASHSEED'] = str(SEED)

#### Original datasets

In [3]:
# Simple synthetic dataset: the source CSV and the mapping from its class
# labels 'A'/'B'/'C' to integer ids 0/1/2.
file_simple = 'data/dataset_10000.csv'
class_dict_simple = dict(A=0, B=1, C=2)
X_train, X_test, y_train, y_test = helper.split_dataset(
    file_simple,
    class_dict_simple,
)

#### Online Training

In [4]:
# One environment instance per split: training rollouts use the train split,
# the final evaluation uses the held-out test split.
simple_online_train_env, simple_online_test_env = (
    SyntheticSimpleEnv(X_train, y_train),
    SyntheticSimpleEnv(X_test, y_test),
)

In [6]:
# Online DQN learner (CPU only).
online_dqn = DQN(
    batch_size=32,
    learning_rate=2.5e-4,
    target_update_interval=100,
    use_gpu=False,
)
# Experience replay bound to the training environment, capped at 50k transitions.
buffer = ReplayBuffer(maxlen=50_000, env=simple_online_train_env)
# Epsilon-greedy exploration, decaying linearly from 1.0 to 0.1 over `duration` steps.
explorer = LinearDecayEpsilonGreedy(
    start_epsilon=1.0,
    end_epsilon=0.1,
    duration=10_000,
)

In [7]:
# Evaluate on a dedicated environment instance. d3rlpy's end-of-epoch
# evaluation resets and steps `eval_env`; passing the training env itself
# would clobber the episode that training is in the middle of. The eval env
# still draws from the training split, as before; the test split is kept
# for the final evaluation further down.
simple_online_eval_env = SyntheticSimpleEnv(X_train, y_train)
online_dqn.fit_online(
    simple_online_train_env,
    buffer,
    explorer=explorer,
    eval_env=simple_online_eval_env,
    n_steps=120000,
    n_steps_per_epoch=1000,
    update_interval=10,
)

2022-07-29 19:05.12 [info     ] Directory is created at d3rlpy_logs\DQN_online_20220729190512
2022-07-29 19:05.12 [debug    ] Building model...
2022-07-29 19:05.12 [debug    ] Model has been built.
2022-07-29 19:05.12 [info     ] Parameters are saved to d3rlpy_logs\DQN_online_20220729190512\params.json params={'action_scaler': None, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 0.00025, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 100, 'use_gpu': None, 'algorithm': 'DQN', 'observation_shape': (3,), 'action_size': 6}


  0%|          | 0/120000 [00:00<?, ?it/s]

2022-07-29 19:05.13 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_1000.pt
2022-07-29 19:05.13 [info     ] DQN_online_20220729190512: epoch=1 step=1000 epoch=1 metrics={'time_inference': 0.0004429876804351807, 'time_environment_step': 1.3001441955566405e-05, 'time_step': 0.0007869985103607178, 'rollout_return': 0.26526315789473687, 'time_sample_batch': 0.00010305335841228053, 'time_algorithm_update': 0.00284526765961008, 'loss': 0.8220997296043278, 'evaluation': -2.6} step=1000
2022-07-29 19:05.14 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_2000.pt
2022-07-29 19:05.14 [info     ] DQN_online_20220729190512: epoch=2 step=2000 epoch=2 metrics={'time_inference': 0.0005050077438354492, 'time_environment_step': 8.996248245239257e-06, 'time_step': 0.0008910181522369385, 'rollout_return': 0.19196428571428573, 'time_sample_batch': 0.00014001846313476562, 'time_algorithm_update': 0.0032301592826843263, 'loss': 0.575

2022-07-29 19:05.28 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_17000.pt
2022-07-29 19:05.28 [info     ] DQN_online_20220729190512: epoch=17 step=17000 epoch=17 metrics={'time_inference': 0.0005289621353149414, 'time_environment_step': 1.1527538299560547e-05, 'time_step': 0.0009734196662902833, 'rollout_return': 2.1544401544401546, 'time_sample_batch': 0.00016013860702514647, 'time_algorithm_update': 0.0037034010887145997, 'loss': 0.27945538103580475, 'evaluation': 2.9} step=17000
2022-07-29 19:05.29 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_18000.pt
2022-07-29 19:05.29 [info     ] DQN_online_20220729190512: epoch=18 step=18000 epoch=18 metrics={'time_inference': 0.00044799613952636717, 'time_environment_step': 8.99958610534668e-06, 'time_step': 0.0008170092105865478, 'rollout_return': 2.289473684210526, 'time_sample_batch': 8.001804351806641e-05, 'time_algorithm_update': 0.0031302809715270997, 'loss'

2022-07-29 19:05.44 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_33000.pt
2022-07-29 19:05.44 [info     ] DQN_online_20220729190512: epoch=33 step=33000 epoch=33 metrics={'time_inference': 0.0005580070018768311, 'time_environment_step': 1.2996196746826171e-05, 'time_step': 0.001020998239517212, 'rollout_return': 2.7868217054263567, 'time_sample_batch': 0.00011000633239746093, 'time_algorithm_update': 0.003959884643554687, 'loss': 0.2284731788188219, 'evaluation': 3.4} step=33000
2022-07-29 19:05.45 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_34000.pt
2022-07-29 19:05.45 [info     ] DQN_online_20220729190512: epoch=34 step=34000 epoch=34 metrics={'time_inference': 0.0005558459758758545, 'time_environment_step': 2.2997379302978515e-05, 'rollout_return': 2.661596958174905, 'time_step': 0.0010211074352264404, 'time_sample_batch': 8.521556854248047e-05, 'time_algorithm_update': 0.0038907313346862793, 'loss': 

2022-07-29 19:06.00 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_49000.pt
2022-07-29 19:06.00 [info     ] DQN_online_20220729190512: epoch=49 step=49000 epoch=49 metrics={'time_inference': 0.0005920681953430175, 'time_environment_step': 1.8992900848388673e-05, 'time_step': 0.0011050002574920655, 'rollout_return': 2.936802973977695, 'time_sample_batch': 7.000207901000977e-05, 'time_algorithm_update': 0.004299731254577637, 'loss': 0.15778151974081994, 'evaluation': 3.3} step=49000
2022-07-29 19:06.02 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_50000.pt
2022-07-29 19:06.02 [info     ] DQN_online_20220729190512: epoch=50 step=50000 epoch=50 metrics={'time_inference': 0.0006319773197174072, 'time_environment_step': 1.299905776977539e-05, 'time_step': 0.001140967845916748, 'rollout_return': 2.7517985611510793, 'time_sample_batch': 0.00010004520416259766, 'time_algorithm_update': 0.004339709281921387, 'loss': 0

2022-07-29 19:06.35 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_65000.pt
2022-07-29 19:06.35 [info     ] DQN_online_20220729190512: epoch=65 step=65000 epoch=65 metrics={'time_inference': 0.0010744876861572266, 'time_environment_step': 2.063560485839844e-05, 'rollout_return': 2.923636363636364, 'time_step': 0.001987217664718628, 'time_sample_batch': 0.0002056121826171875, 'time_algorithm_update': 0.007682294845581055, 'loss': 0.13588536409661173, 'evaluation': 3.6} step=65000
2022-07-29 19:06.37 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_66000.pt
2022-07-29 19:06.37 [info     ] DQN_online_20220729190512: epoch=66 step=66000 epoch=66 metrics={'time_inference': 0.0012471253871917725, 'time_environment_step': 3.488659858703613e-05, 'time_step': 0.002303144931793213, 'rollout_return': 2.88212927756654, 'time_sample_batch': 0.00027138471603393554, 'time_algorithm_update': 0.008803274631500244, 'loss': 0.121

2022-07-29 19:07.13 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_81000.pt
2022-07-29 19:07.13 [info     ] DQN_online_20220729190512: epoch=81 step=81000 epoch=81 metrics={'time_inference': 0.0012333672046661377, 'time_environment_step': 3.1943082809448245e-05, 'rollout_return': 3.0858208955223883, 'time_step': 0.0022273242473602297, 'time_sample_batch': 0.0002346062660217285, 'time_algorithm_update': 0.008483984470367432, 'loss': 0.11449691937305033, 'evaluation': 3.6} step=81000
2022-07-29 19:07.16 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_82000.pt
2022-07-29 19:07.16 [info     ] DQN_online_20220729190512: epoch=82 step=82000 epoch=82 metrics={'time_inference': 0.0013758890628814697, 'time_environment_step': 3.577995300292969e-05, 'time_step': 0.002539657115936279, 'rollout_return': 2.92619926199262, 'time_sample_batch': 0.00027652740478515623, 'time_algorithm_update': 0.009724953174591065, 'loss': 0.

2022-07-29 19:07.53 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_97000.pt
2022-07-29 19:07.53 [info     ] DQN_online_20220729190512: epoch=97 step=97000 epoch=97 metrics={'time_inference': 0.0014285097122192383, 'time_environment_step': 4.713153839111328e-05, 'rollout_return': 3.071969696969697, 'time_step': 0.0026534504890441896, 'time_sample_batch': 0.00028429031372070315, 'time_algorithm_update': 0.01008375883102417, 'loss': 0.11200985274277628, 'evaluation': 3.4} step=97000
2022-07-29 19:07.55 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_98000.pt
2022-07-29 19:07.55 [info     ] DQN_online_20220729190512: epoch=98 step=98000 epoch=98 metrics={'time_inference': 0.0012408924102783204, 'time_environment_step': 3.900504112243652e-05, 'time_step': 0.0023322315216064453, 'rollout_return': 3.18, 'time_sample_batch': 0.00031463623046875, 'time_algorithm_update': 0.009060571193695069, 'loss': 0.1128297182172536

2022-07-29 19:08.33 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_113000.pt
2022-07-29 19:08.33 [info     ] DQN_online_20220729190512: epoch=113 step=113000 epoch=113 metrics={'time_inference': 0.0015215396881103516, 'time_environment_step': 4.035019874572754e-05, 'time_step': 0.002821653604507446, 'rollout_return': 2.951310861423221, 'time_sample_batch': 0.0003615951538085938, 'time_algorithm_update': 0.01077650785446167, 'loss': 0.11042534437030554, 'evaluation': 3.0} step=113000
2022-07-29 19:08.35 [info     ] Model parameters are saved to d3rlpy_logs\DQN_online_20220729190512\model_114000.pt
2022-07-29 19:08.35 [info     ] DQN_online_20220729190512: epoch=114 step=114000 epoch=114 metrics={'time_inference': 0.0010175437927246094, 'time_environment_step': 2.8738021850585938e-05, 'time_step': 0.001891339063644409, 'rollout_return': 3.0076045627376424, 'time_sample_batch': 0.0002674603462219238, 'time_algorithm_update': 0.007377159595489502, 'lo

In [None]:
# Evaluate the trained online DQN on the held-out test environment via the
# project helper; the result is a DataFrame with 'y_pred' / 'y_actual' columns
# (both are indexed below).
online_test_df = helper.test_d3rlpy_dqn(
    online_dqn,
    simple_online_test_env,
)

In [None]:
# Sanity check: number of test samples vs. number of rows in the results frame.
(len(X_test), len(online_test_df))

In [None]:
# Rows whose 'y_pred' is not NaN, i.e. the agent produced a prediction, and
# among those the rows where the prediction matches the true label.
# The frame being filtered must be the same one the mask was built from
# (`online_test_df`); there is no `test_df` in this notebook.
has_prediction = online_test_df['y_pred'].notna()
y_pred_df = online_test_df[has_prediction]
success_df = y_pred_df[y_pred_df['y_pred'] == y_pred_df['y_actual']]
len(success_df)

In [None]:
# Percentage of ALL test rows (including those with no prediction) that were
# classified correctly.
n_correct = len(success_df)
n_total = len(online_test_df)
success_rate = n_correct / n_total * 100
success_rate

In [None]:
#avg length and return 
avg_length, avg_return = helper.get_avg_length_reward(online_test_df)
avg_length, avg_return

In [None]:
# Classification metrics over the rows that received a prediction.
# `test` is neither defined nor imported in this notebook, so the metrics are
# computed directly with scikit-learn.
# 'y_pred' can contain NaN (see the notna() filter above), which may leave it
# as a float column; those rows are already dropped, so cast both to int.
y_true_eval = y_pred_df['y_actual'].astype(int)
y_pred_eval = y_pred_df['y_pred'].astype(int)
class_ids = sorted(class_dict_simple.values())

acc = accuracy_score(y_true_eval, y_pred_eval)
# Three classes, so F1 needs an explicit multi-class average.
f1 = f1_score(y_true_eval, y_pred_eval, average='macro')
# Only hard labels are available (no class scores), so ROC AUC is computed
# one-vs-rest on the one-hot encoded labels and macro-averaged.
# NOTE(review): confirm this matches the metric the missing `test` produced.
roc_auc = roc_auc_score(
    label_binarize(y_true_eval, classes=class_ids),
    label_binarize(y_pred_eval, classes=class_ids),
    average='macro',
)
acc, f1, roc_auc

#### Offline Training

In [None]:
# Logged transitions for offline training; presumably written by
# helper.create_d3rlpy_dataset in the "Creating datasets" section below.
# MDPDataset is not imported at the top of the notebook, so reference it
# through the already-imported d3rlpy package.
offline_dataset = d3rlpy.dataset.MDPDataset.load('data/dqn_simple_dataset.h5')
simple_offline_train_env = SyntheticSimpleEnv(X_train, y_train)
simple_offline_test_env = SyntheticSimpleEnv(X_test, y_test)

In [None]:
# Offline DQN with d3rlpy's default hyperparameters, trained on the logged
# dataset and scored by rolling out in the training environment.
offline_dqn = d3rlpy.algos.DQN()
environment_scorer = d3rlpy.metrics.evaluate_on_environment(simple_offline_train_env)
offline_dqn.fit(
    offline_dataset,
    eval_episodes=offline_dataset.episodes,
    n_steps=120000,
    n_steps_per_epoch=10000,
    scorers={"environment": environment_scorer},
)

#### Creating datasets

In [None]:
# 
# complex_env = SyntheticComplexEnv(X_train_comp, y_train_comp)

In [None]:
# random_simple_dataset = helper.create_d3rlpy_dataset('random', simple_env, 'data/random_simple_dataset.h5')
# dqn_simple_dataset = helper.create_d3rlpy_dataset('dqn', simple_env, 'data/dqn_simple_dataset.h5')

In [None]:
# dqn_complex_dataset = helper.create_dataset('dqn', complex_env, 'data/dqn_complex_dataset.h5')
# random_complex_dataset = helper.create_dataset('random', complex_env, 'data/random_complex_dataset.h5')

#### Training the DQNs

In [None]:
# simple_random_dqn = helper.train_dqn('data/random_simple_dataset.h5', simple_env, 'models/random_simple_dqn')

In [None]:
# Train a DQN on the DQN-collected simple dataset. `train_dqn` and `simple_env`
# are not defined in this notebook; the commented cell above shows the intended
# call is helper.train_dqn with an env over the simple training split.
# The third argument appears to be the model output path — TODO confirm in helper.
simple_env = SyntheticSimpleEnv(X_train, y_train)
simple_dqn_dqn = helper.train_dqn('data/dqn_simple_dataset.h5', simple_env, 'models/dqn_simple_dqn')

In [None]:
# complex_dqn_dqn = helper.train_dqn('data/dqn_complex_dataset.h5', complex_env, 'models/dqn_complex_dqn')

In [None]:
# complex_random_dqn = helper.train_dqn('data/random_complex_dataset.h5', complex_env, 'models/random_complex_dqn')

#### Testing the trained DQNs

In [None]:
# Smoke test on the first 5 held-out samples (the split from the
# "Original datasets" section is X_test / y_test; X_test_simp does not exist).
# NOTE: `simple_random_dqn` is only created by the commented-out
# helper.train_dqn cell under "Training the dqns"; run that first.
# NOTE(review): if that model is a d3rlpy DQN, helper.test_d3rlpy_dqn (used
# for the online model) may be the intended tester — confirm in helper.
testing_simple_env = SyntheticSimpleEnv(X_test[:5], y_test[:5], random=False)
simple_random_test_df = helper.test_dqn(simple_random_dqn, testing_simple_env)
simple_random_test_df

In [None]:
# Features of the first five test samples (the ones used in the smoke test above).
X_test[0:5]