# etdqn_train.py (47 lines, 1.44 KB) — web-scrape page chrome and line-number gutter removed
"""Train an ETDQN agent on the highway-v0 environment and report wall-clock training time."""
import time

import gym

from etdqn import ETDQN
# Explicit imports instead of `from params import *` so every configuration
# constant this script depends on is visible here (PEP 8: avoid wildcard imports).
from params import (
    DOUBLE_DQN,
    DUELING_DQN,
    PARAM_EGREEDY,
    PARAM_EGREEDY_INC,
    PARAM_GAMMA,
    PARAM_INTERVAL_TEST,
    PARAM_LEARN_START,
    PARAM_LR,
    PARAM_MAX_STEPS,
    PARAM_MAX_STEPS_ONE_EP,
    PARAM_NORMAL_BATCH_SIZE,
    PARAM_NORMAL_MEM_SIZE,
    PARAM_PF,
    PARAM_SF,
    PARAM_TARGET_UPDATE_INTERVAL,
    PARAM_VEH_COUNT,
    PARAM_VEH_DENSITY,
    PARAM_VEH_OBS_COUNT,
)

# Environment setup: highway-v0 with discrete meta-actions and kinematic
# observations of the PARAM_VEH_OBS_COUNT nearest vehicles.
env = gym.make('highway-v0')
config = {
    'action': {'type': 'DiscreteMetaAction'},
    'observation': {
        'type': 'Kinematics',
        'vehicles_count': PARAM_VEH_OBS_COUNT,
        'features': ['presence', 'x', 'y', 'vx', 'vy'],
    },
    'manual_control': False,
    'simulation_frequency': PARAM_SF,  # physics steps per second
    'policy_frequency': PARAM_PF,      # agent decisions per second
    'vehicles_count': PARAM_VEH_COUNT,
    'vehicles_density': PARAM_VEH_DENSITY,
}
env.config.update(config)
env.reset()

# Build the agent. Flags enable the Double-DQN target and Dueling-DQN head
# variants; all hyperparameters come from params.py.
ad_model = ETDQN(
    env,
    memory_normal_size=PARAM_NORMAL_MEM_SIZE,
    batch_normal_size=PARAM_NORMAL_BATCH_SIZE,
    episode_max_steps=PARAM_MAX_STEPS_ONE_EP,
    e_greedy_increment=PARAM_EGREEDY_INC,
    e_greedy=PARAM_EGREEDY,
    learning_rate=PARAM_LR,
    reward_decay=PARAM_GAMMA,
    target_update_interval=PARAM_TARGET_UPDATE_INTERVAL,
    DOUBLE_DQN=DOUBLE_DQN,
    DUELING_DQN=DUELING_DQN,
)

# Train for PARAM_MAX_STEPS environment steps (learning starts after
# PARAM_LEARN_START warm-up steps; a checkpoint/test every PARAM_INTERVAL_TEST).
t0 = time.time()
ad_model.learn(
    learn_start=PARAM_LEARN_START,
    max_time_steps=PARAM_MAX_STEPS,
    interval_steps_save=PARAM_INTERVAL_TEST,
)
print('training time: ', time.time()-t0)