-
Notifications
You must be signed in to change notification settings - Fork 6
/
natureqn.py
executable file
·63 lines (49 loc) · 2.2 KB
/
natureqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import tensorflow as tf
import tensorflow.contrib.layers as layers
from utils.general import get_logger
from utils.test_env import EnvTest
from schedule import LinearExploration, LinearSchedule
from linear import Linear
from config import testconfig_teacher as config
class NatureQN(Linear):
    """
    Implementing DeepMind's Nature paper. Here are the relevant urls.
    https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf
    https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf

    Builds the Q network with the Nature-DQN layer layout
    (3 convolutions -> 1 hidden fully-connected layer -> linear output).
    When ``self.student`` is truthy, a narrower "compressed" variant with
    fewer filters and hidden units is built instead.
    """

    def get_q_values_op(self, state, scope, reuse=False):
        """
        Returns Q values for all actions

        Args:
            state: (tf tensor)
                shape = (batch_size, img height, img width, nchannels)
            scope: (string) scope name, that specifies if target network or not
            reuse: (bool) reuse of variables in the scope

        Returns:
            out: (tf tensor) of shape = (batch_size, num_actions)
        """
        num_actions = self.env.action_space.n

        # Compress the student network: the student gets far fewer filters /
        # hidden units; the teacher sizes (32, 64, 64, 512) match Nature DQN.
        size1, size2, size3, size4 = (
            (16, 16, 16, 128) if self.student else (32, 64, 64, 512)
        )

        with tf.variable_scope(scope, reuse=reuse):
            # `state` is a 4-D (batch, h, w, c) image batch, so these must be
            # 2-D convolutions; conv3d requires a 5-D input and would fail.
            # Spatial sizes in the trailing comments assume an 80x80 input
            # and the layers' default 'SAME' padding.
            conv1 = layers.conv2d(inputs=state, num_outputs=size1,
                                  kernel_size=[8, 8], stride=4)  # 20x20
            conv2 = layers.conv2d(inputs=conv1, num_outputs=size2,
                                  kernel_size=[4, 4], stride=2)  # 10x10
            conv3 = layers.conv2d(inputs=conv2, num_outputs=size3,
                                  kernel_size=[3, 3], stride=1)  # 10x10
            hidden = layers.fully_connected(layers.flatten(conv3), size4)
            # Linear output layer (activation_fn=None): one unbounded Q value
            # per action.
            out = layers.fully_connected(hidden, num_actions,
                                         activation_fn=None)
        return out
# Use deep Q network for test environment.
# (A comment rather than a bare string literal: a string here is not the
# module docstring, just an expression statement that is evaluated and
# discarded.)
if __name__ == '__main__':
    # presumably (height, width, channels) of the test observations — confirm
    # against EnvTest
    env = EnvTest((80, 80, 1))

    # exploration strategy
    exp_schedule = LinearExploration(env, config.eps_begin,
                                     config.eps_end, config.eps_nsteps)

    # learning rate schedule
    lr_schedule = LinearSchedule(config.lr_begin, config.lr_end,
                                 config.lr_nsteps)

    # train model
    model = NatureQN(env, config)
    model.run(exp_schedule, lr_schedule)