-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
206 lines (170 loc) · 5.18 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import tensorflow as tf
class Model:
    """The base class for operating a Reinforcement Learning deep net in
    TensorFlow. All networks descend from this class.

    Subclasses must fill in `obs_dtype`, `obs_shape`, `act_dtype` and
    `act_shape` before calling `self._build()`, and are expected to assign
    the private placeholder/op attributes inside their `build()` override.
    """

    def __init__(self):
        # Input TF placeholders that must be set
        self._obs_t_ph = None
        self._act_t_ph = None
        self._rew_t_ph = None
        self._obs_tp1_ph = None
        self._done_ph = None
        # TF Ops that should be set
        self._train_op = None
        self._update_target = None  # Optional
        self._action = None         # Optional
        # Input signature; must be provided by the subclass before _build()
        self.obs_dtype = None
        self.obs_shape = None
        self.act_dtype = None
        self.act_shape = None

    def build(self):
        """Construct the model graph. Must be implemented by subclasses."""
        raise NotImplementedError()

    def _build(self):
        """Create the standard input placeholders shared by all models.

        Assumes `obs_dtype`, `obs_shape`, `act_dtype`, `act_shape` have been
        set by the subclass; `obs_shape`/`act_shape` must be lists so they can
        be prepended with the batch dimension.
        """
        self._obs_t_ph = tf.placeholder(self.obs_dtype, [None] + self.obs_shape, name="obs_t_ph")
        self._act_t_ph = tf.placeholder(self.act_dtype, [None] + self.act_shape, name="act_t_ph")
        self._rew_t_ph = tf.placeholder(tf.float32, [None], name="rew_t_ph")
        self._obs_tp1_ph = tf.placeholder(self.obs_dtype, [None] + self.obs_shape, name="obs_tp1_ph")
        self._done_ph = tf.placeholder(tf.bool, [None], name="done_ph")

    def restore(self, graph):
        """Restore the Variables, placeholders and Ops needed by the class so that
        it can operate in exactly the same way as if `self.build()` was called

        Args:
            graph: tf.Graph. Graph, restored from a checkpoint
        """
        # (attribute, lookup function, element name) triples. Each lookup is
        # best-effort: tf.Graph raises KeyError for a missing element, in
        # which case the attribute keeps its current value (None for a fresh
        # instance) — some elements are optional for some models.
        elements = (
            ("_train_op", graph.get_operation_by_name, "train_op"),
            ("_update_target", graph.get_operation_by_name, "update_target"),
            ("_obs_t_ph", graph.get_tensor_by_name, "obs_t_ph:0"),
            ("_act_t_ph", graph.get_tensor_by_name, "act_t_ph:0"),
            ("_rew_t_ph", graph.get_tensor_by_name, "rew_t_ph:0"),
            ("_obs_tp1_ph", graph.get_tensor_by_name, "obs_tp1_ph:0"),
            ("_done_ph", graph.get_tensor_by_name, "done_ph:0"),
            ("_action", graph.get_tensor_by_name, "action:0"),
        )
        for attr, getter, name in elements:
            try:
                setattr(self, attr, getter(name))
            except KeyError:
                pass  # element absent from the checkpoint; keep current value
        # Let the subclass restore any model-specific elements
        self._restore(graph)

    def initialize(self, sess):
        """Run additional initialization for the model when it was created via
        self.build(). Assumes that tf.global_variables_initializer() and
        tf.local_variables_initializer() have already been run
        """
        raise NotImplementedError()

    def control_action(self, sess, state):
        """Compute control action for the model. NOTE that this should NOT include
        any exploration policy, but should only return the action that would be
        performed if the model was being evaluated

        Args:
            sess: tf.Session(). Currently open session
            state: np.array. Observation for the current state
        Returns:
            The calculated action. Type and shape varies based on the specific model
        """
        raise NotImplementedError()

    def _restore(self, graph):
        """Restore model-specific graph elements. Implemented by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def _require(value):
        """Return `value`, raising NotImplementedError if it was never set.

        Used by the property getters below so a model that did not set an
        optional element fails loudly on access rather than returning None.
        """
        if value is None:
            raise NotImplementedError()
        return value

    @property
    def name(self):
        """
        Returns:
          name of the model class
        """
        return self.__class__.__name__

    @property
    def train_op(self):
        """
        Returns:
          `tf.Op` that trains the network. Requires that `self.obs_t_ph`,
          `self.act_t_ph`, `self.obs_tp1_ph`, `self.done_ph` placeholders
          are set via feed_dict. Might require other placeholders as well.
        """
        return self._require(self._train_op)

    @property
    def update_target(self):
        """
        Returns:
          `tf.Op` that updates the target network (if one is used).
        """
        return self._require(self._update_target)

    @property
    def obs_t_ph(self):
        """
        Returns:
          `tf.placeholder` for observations at time t from the training batch
        """
        return self._require(self._obs_t_ph)

    @property
    def act_t_ph(self):
        """
        Returns:
          `tf.placeholder` for actions at time t from the training batch
        """
        return self._require(self._act_t_ph)

    @property
    def rew_t_ph(self):
        """
        Returns:
          `tf.placeholder` for rewards at time t from the training batch
        """
        return self._require(self._rew_t_ph)

    @property
    def obs_tp1_ph(self):
        """
        Returns:
          `tf.placeholder` for observations at time t+1 from the training batch
        """
        return self._require(self._obs_tp1_ph)

    @property
    def done_ph(self):
        """
        Returns:
          `tf.placeholder` to indicate end of episode for examples in the training batch
        """
        return self._require(self._done_ph)