# tf_util.py
import collections
import numpy as np
import os
import tensorflow as tf
def sum(x, axis=None, keepdims=False):
    """Sum the elements of `x` with `tf.reduce_sum`.

    Parameters
    ----------
    x: tensor to reduce
    axis: None to pass `axis=None` to TensorFlow, otherwise a single axis
        which is wrapped in a one-element list before being forwarded
    keepdims: forwarded to TensorFlow as `keep_dims`
    """
    reduction_axes = [axis] if axis is not None else None
    return tf.reduce_sum(x, axis=reduction_axes, keep_dims=keepdims)
def mean(x, axis=None, keepdims=False):
    """Average the elements of `x` with `tf.reduce_mean`.

    Parameters
    ----------
    x: tensor to reduce
    axis: None to pass `axis=None` to TensorFlow, otherwise a single axis
        which is wrapped in a one-element list before being forwarded
    keepdims: forwarded to TensorFlow as `keep_dims`
    """
    reduction_axes = [axis] if axis is not None else None
    return tf.reduce_mean(x, axis=reduction_axes, keep_dims=keepdims)
def var(x, axis=None, keepdims=False):
    """Compute the (biased) variance of `x` as mean((x - mean(x))**2).

    Parameters
    ----------
    x: tensor to reduce
    axis: None or a single axis, forwarded to `mean`
    keepdims: whether the returned tensor keeps the reduced axis; forwarded
        to the outer `mean` only

    Returns
    -------
    The tensor produced by `mean` applied to the squared deviations.
    """
    # The inner mean must keep the reduced axis so that it broadcasts against
    # `x` no matter what `keepdims` the caller asked for; only the final
    # reduction honours the caller's `keepdims`.
    meanx = mean(x, axis=axis, keepdims=True)
    return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)
def std(x, axis=None, keepdims=False):
    """Compute the standard deviation of `x` as the square root of `var`.

    `axis` and `keepdims` are passed through unchanged to `var`.
    """
    variance = var(x, axis=axis, keepdims=keepdims)
    return tf.sqrt(variance)
def max(x, axis=None, keepdims=False):
    """Take the maximum of `x` with `tf.reduce_max`.

    Parameters
    ----------
    x: tensor to reduce
    axis: None to pass `axis=None` to TensorFlow, otherwise a single axis
        which is wrapped in a one-element list before being forwarded
    keepdims: forwarded to TensorFlow as `keep_dims`
    """
    reduction_axes = [axis] if axis is not None else None
    return tf.reduce_max(x, axis=reduction_axes, keep_dims=keepdims)
def min(x, axis=None, keepdims=False):
    """Take the minimum of `x` with `tf.reduce_min`.

    Parameters
    ----------
    x: tensor to reduce
    axis: None to pass `axis=None` to TensorFlow, otherwise a single axis
        which is wrapped in a one-element list before being forwarded
    keepdims: forwarded to TensorFlow as `keep_dims`
    """
    reduction_axes = [axis] if axis is not None else None
    return tf.reduce_min(x, axis=reduction_axes, keep_dims=keepdims)
def concatenate(arrs, axis=0):
    """Join the tensors in `arrs` along `axis` using `tf.concat`.

    Both arguments are passed to TensorFlow by keyword.
    """
    joined = tf.concat(values=arrs, axis=axis)
    return joined
def argmax(x, axis=None):
    """Thin wrapper that forwards `x` and `axis` to `tf.argmax`."""
    indices = tf.argmax(x, axis=axis)
    return indices
def softmax(x, axis=None):
    """Thin wrapper that forwards `x` and `axis` to `tf.nn.softmax`."""
    probabilities = tf.nn.softmax(x, axis=axis)
    return probabilities
# ================================================================
# Misc
# ================================================================
def is_placeholder(x):
    """Tell whether `x` looks like a placeholder.

    Returns True only when `x` is exactly of type `tf.Tensor` (subclasses do
    not count) and the op that produces it has no inputs.
    """
    if type(x) is not tf.Tensor:
        return False
    return len(x.op.inputs) == 0
# ================================================================
# Inputs
# ================================================================
class TfInput(object):
    """Base class for a generalized Tensorflow placeholder.

    Subclasses must override `get` and `make_feed_dict`; the versions defined
    here only raise `NotImplementedError`.
    """

    def __init__(self, name="(unnamed)"):
        """Generalized Tensorflow placeholder. The main differences are:
            - possibly uses multiple placeholders internally and returns multiple values
            - can apply light postprocessing to the value feed to placeholder.

        Parameters
        ----------
        name: str
            name stored on the instance as `self.name`
        """
        self.name = name

    def get(self):
        """Return the tf variable(s) representing the possibly postprocessed value
        of placeholder(s).

        Raises
        ------
        NotImplementedError
            always; subclasses provide the real implementation.
        """
        # `NotImplemented` is a non-callable singleton meant for binary
        # operators; the exception class is `NotImplementedError`.
        raise NotImplementedError()

    def make_feed_dict(self, data):
        """Given data input it to the placeholder(s).

        Raises
        ------
        NotImplementedError
            always; subclasses provide the real implementation.
        """
        raise NotImplementedError()
class PlacholderTfInput(TfInput):
    """Adapter that exposes a regular tensorflow placeholder as a TfInput."""

    def __init__(self, placeholder):
        """Store `placeholder` and reuse its `.name` as this input's name."""
        super().__init__(name=placeholder.name)
        self._placeholder = placeholder

    def get(self):
        """Return the wrapped placeholder unchanged."""
        return self._placeholder

    def make_feed_dict(self, data):
        """Return a one-entry feed dict mapping the wrapped placeholder to `data`."""
        feed = {}
        feed[self._placeholder] = data
        return feed
class BatchInput(PlacholderTfInput):
    """Placeholder-backed input whose leading (batch) dimension is unspecified."""

    def __init__(self, shape, dtype=tf.float32, name=None):
        """Creates a placeholder for a batch of tensors of a given shape and dtype

        Parameters
        ----------
        shape: [int]
            shape of a single element of the batch; a leading `None`
            dimension is prepended for the batch axis
        dtype: tf.dtype
            number representation used for tensor contents
        name: str
            name of the underlying placeholder
        """
        batched_shape = [None] + list(shape)
        batch_placeholder = tf.placeholder(dtype, batched_shape, name=name)
        super().__init__(batch_placeholder)
class Uint8Input(PlacholderTfInput):
    """Input fed as uint8 and exposed to the model as float32 scaled by 1/255."""

    def __init__(self, shape, name=None):
        """Takes input in uint8 format which is cast to float32 and divided by 255
        before passing it to the model.

        Per the original author's note, feeding uint8 is intended to lower
        data transfer times on GPU.

        Parameters
        ----------
        shape: [int]
            shape of the tensor; a leading `None` dimension is prepended.
        name: str
            name of the underlying placeholder
        """
        batched_shape = [None] + list(shape)
        super().__init__(tf.placeholder(tf.uint8, batched_shape, name=name))
        self._shape = shape
        # Build the post-processed tensor once so every get() returns the same op.
        raw = super().get()
        as_float = tf.cast(raw, tf.float32)
        self._output = as_float / 255.0

    def get(self):
        """Return the float32, 1/255-scaled tensor built in the constructor."""
        return self._output
def ensure_tf_input(thing):
    """Coerce `thing` into a TfInput.

    A TfInput is returned as-is; anything `is_placeholder` accepts is wrapped
    in a PlacholderTfInput.

    Raises
    ------
    ValueError
        if `thing` is neither a TfInput nor a placeholder.
    """
    if isinstance(thing, TfInput):
        return thing
    if not is_placeholder(thing):
        raise ValueError("Must be a placeholder or TfInput")
    return PlacholderTfInput(thing)
# ================================================================
# Mathematical utils
# ================================================================
def huber_loss(x, delta=1.0):
    """Elementwise Huber loss of `x`.

    Where |x| < delta the value is 0.5 * x**2; elsewhere it is
    delta * (|x| - 0.5 * delta).

    Reference: https://en.wikipedia.org/wiki/Huber_loss
    """
    within_delta = tf.abs(x) < delta
    quadratic_part = tf.square(x) * 0.5
    linear_part = delta * (tf.abs(x) - 0.5 * delta)
    return tf.where(within_delta, quadratic_part, linear_part)
# ================================================================
# Optimizer utils
# ================================================================
def minimize_and_clip(optimizer, objective, var_list, clip_val=10):
    """Build an op minimizing `objective` with `optimizer` over `var_list`.

    When `clip_val` is None this is simply `optimizer.minimize`. Otherwise
    each non-None gradient is passed through `tf.clip_by_norm(grad, clip_val)`
    before the gradients are applied.

    Returns
    -------
    The op returned by `optimizer.minimize` or `optimizer.apply_gradients`.
    """
    if clip_val is None:
        return optimizer.minimize(objective, var_list=var_list)

    grads_and_vars = optimizer.compute_gradients(objective, var_list=var_list)
    clipped = [
        (g if g is None else tf.clip_by_norm(g, clip_val), v)
        for g, v in grads_and_vars
    ]
    return optimizer.apply_gradients(clipped)
# ================================================================
# Global session
# ================================================================
def get_session():
    """Return TensorFlow's current default session (`tf.get_default_session()`)."""
    default_session = tf.get_default_session()
    return default_session
def make_session(num_cpu):
    """Create a `tf.Session` limited to `num_cpu` threads.

    Both the inter-op and intra-op parallelism thread counts of the session
    config are set to `num_cpu`.
    """
    thread_limits = {
        "inter_op_parallelism_threads": num_cpu,
        "intra_op_parallelism_threads": num_cpu,
    }
    session_config = tf.ConfigProto(**thread_limits)
    return tf.Session(config=session_config)
def single_threaded_session():
    """Create a session via `make_session` with `num_cpu` set to 1."""
    return make_session(num_cpu=1)
# Variables that a previous call to initialize() has already initialized.
ALREADY_INITIALIZED = set()


def initialize():
    """Initialize every global variable not yet handled by this function.

    Runs a `tf.variables_initializer` op in the session returned by
    `get_session()` for the global variables missing from
    `ALREADY_INITIALIZED`, then records them there.
    """
    pending = set(tf.global_variables()).difference(ALREADY_INITIALIZED)
    init_op = tf.variables_initializer(pending)
    get_session().run(init_op)
    ALREADY_INITIALIZED.update(pending)
# ================================================================
# Scopes
# ================================================================
def scope_vars(scope, trainable_only=False):
    """
    Collect the variables that live inside a scope.

    Parameters
    ----------
    scope: str or VariableScope
        scope in which the variables reside; a non-string is converted
        through its `.name` attribute.
    trainable_only: bool
        if True, look only in the TRAINABLE_VARIABLES collection, otherwise
        in GLOBAL_VARIABLES.

    Returns
    -------
    vars: [tf.Variable]
        list of variables in `scope`.
    """
    if trainable_only:
        collection_key = tf.GraphKeys.TRAINABLE_VARIABLES
    else:
        collection_key = tf.GraphKeys.GLOBAL_VARIABLES
    scope_str = scope if isinstance(scope, str) else scope.name
    return tf.get_collection(collection_key, scope=scope_str)
def scope_name():
    """Return the current variable scope's name as a string, e.g. deepq/q_func"""
    current_scope = tf.get_variable_scope()
    return current_scope.name
def absolute_scope_name(relative_scope_name):
    """Prefix `relative_scope_name` with the current scope name and a "/"."""
    return "/".join((scope_name(), relative_scope_name))
# ================================================================
# Saving variables
# ================================================================
def load_state(fname, saver=None):
    """Restore variables from <fname> into the session from `get_session()`.

    Parameters
    ----------
    fname: path handed to `saver.restore`
    saver: optional saver to use; a fresh `tf.train.Saver()` is created
        when None

    Returns
    -------
    The saver that performed the restore.
    """
    active_saver = tf.train.Saver() if saver is None else saver
    active_saver.restore(get_session(), fname)
    return active_saver
def save_state(fname, saver=None):
    """Save variables of the session from `get_session()` to <fname>.

    The parent directory of `fname` is created first if it is missing.

    Parameters
    ----------
    fname: path handed to `saver.save`
    saver: optional saver to use; a fresh `tf.train.Saver()` is created
        when None

    Returns
    -------
    The saver that performed the save.
    """
    # A bare file name has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    target_dir = os.path.dirname(fname)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    if saver is None:
        saver = tf.train.Saver()
    saver.save(get_session(), fname)
    return saver
# ================================================================
# Theano-like Function
# ================================================================
def function(inputs, outputs, updates=None, givens=None):
    """Just like Theano function. Takes a bunch of tensorflow placeholders and expressions
    computed based on those placeholders and produces f(inputs) -> outputs. Function f takes
    values to be fed to the input placeholders and produces the values of the expressions
    in outputs.

    Input values can be passed in the same order as inputs or can be provided as kwargs based
    on placeholder name (passed to constructor or accessible via placeholder.op.name).

    Example:
        x = tf.placeholder(tf.int32, (), name="x")
        y = tf.placeholder(tf.int32, (), name="y")
        z = 3 * x + 2 * y
        lin = function([x, y], z, givens={y: 0})

        with single_threaded_session():
            initialize()

            assert lin(2) == 6
            assert lin(x=3) == 9
            assert lin(2, 2) == 10
            assert lin(x=2, y=3) == 12

    Parameters
    ----------
    inputs: [tf.placeholder or TfInput]
        list of input arguments
    outputs: [tf.Variable] or tf.Variable
        list of outputs or a single output to be returned from function. Returned
        value will also have the same shape. A dict of outputs is also accepted,
        in which case the result is a mapping of the same type with the same keys.
    updates: optional ops to run alongside the outputs (forwarded to _Function)
    givens: optional mapping of default feed values (forwarded to _Function)
    """
    # A list of outputs: the _Function itself already returns a list.
    if isinstance(outputs, list):
        return _Function(inputs, outputs, updates, givens=givens)

    # A mapping of outputs: evaluate the values and rebuild the same mapping type.
    if isinstance(outputs, (dict, collections.OrderedDict)):
        mapping_fn = _Function(inputs, outputs.values(), updates, givens=givens)

        def call_as_mapping(*args, **kwargs):
            mapping_type = type(outputs)
            keys = outputs.keys()
            return mapping_type(zip(keys, mapping_fn(*args, **kwargs)))

        return call_as_mapping

    # A single output: wrap it in a list and unwrap the single result.
    single_fn = _Function(inputs, [outputs], updates, givens=givens)

    def call_single(*args, **kwargs):
        return single_fn(*args, **kwargs)[0]

    return call_single
class _Function(object):
    """Callable that feeds values to inputs and runs outputs in the default session."""

    def __init__(self, inputs, outputs, updates, givens, check_nan=False):
        """Validate the inputs and prepare the list of fetches.

        Parameters
        ----------
        inputs: sequence of TfInput objects or placeholders
        outputs: iterable of tensors to fetch
        updates: iterable of ops grouped with `tf.group` and run with every
            call; None is treated as empty
        givens: mapping of input -> default feed value; None is treated as empty
        check_nan: if True, a call raises RuntimeError when any result has a NaN
        """
        for candidate in inputs:
            if issubclass(type(candidate), TfInput):
                continue
            assert len(candidate.op.inputs) == 0, "inputs should all be placeholders of rl_algs.common.TfInput"
        self.inputs = inputs
        self.update_group = tf.group(*(updates or []))
        # The update group is fetched last so it can be sliced off the results.
        self.outputs_update = list(outputs) + [self.update_group]
        self.givens = givens if givens is not None else {}
        self.check_nan = check_nan

    def _feed_input(self, feed_dict, inpt, value):
        """Add `value` for `inpt` to `feed_dict`.

        TfInput instances contribute through `make_feed_dict`; placeholders
        are keyed directly. Anything else is left out of the feed dict.
        """
        if issubclass(type(inpt), TfInput):
            feed_dict.update(inpt.make_feed_dict(value))
            return
        if is_placeholder(inpt):
            feed_dict[inpt] = value

    def __call__(self, *args, **kwargs):
        """Run the outputs, feeding positional and keyword values to the inputs.

        Positional values are matched to `self.inputs` in order. The remaining
        inputs are looked up in `kwargs` by the last "/"-separated component
        of their name with any ":..." suffix removed; an input found in
        neither must be a key of `self.givens`.

        Returns
        -------
        The list of fetched output values (the update group's result is dropped).
        """
        assert len(args) <= len(self.inputs), "Too many arguments provided"
        feed_dict = {}
        # Positional arguments first.
        for inpt, value in zip(self.inputs, args):
            self._feed_input(feed_dict, inpt, value)
        # Then keyword arguments for the inputs not covered positionally.
        seen_names = set()
        for inpt in self.inputs[len(args):]:
            short_name = inpt.name.split(':')[0].split('/')[-1]
            assert short_name not in seen_names, \
                "this function has two arguments with the same name \"{}\", so kwargs cannot be used.".format(short_name)
            if short_name not in kwargs:
                assert inpt in self.givens, "Missing argument " + short_name
                continue
            seen_names.add(short_name)
            self._feed_input(feed_dict, inpt, kwargs.pop(short_name))
        assert len(kwargs) == 0, "Function got extra arguments " + str(list(kwargs.keys()))
        # Givens act as defaults: they never override a value already fed.
        for given, default in self.givens.items():
            feed_dict.setdefault(given, default)
        fetched = get_session().run(self.outputs_update, feed_dict=feed_dict)
        results = fetched[:-1]
        if self.check_nan and any(np.isnan(r).any() for r in results):
            raise RuntimeError("Nan detected")
        return results