-
Notifications
You must be signed in to change notification settings - Fork 298
/
backend.py
422 lines (361 loc) · 15.5 KB
/
backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
"""Backend for running ONNX on Tensorflow
To run this, you will need to have Tensorflow installed as well.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
# Python 2 compatibility: use itertools.izip as a lazy zip. On Python 3 the
# built-in zip is already lazy, so the ImportError is expected and ignored.
try:
  from itertools import izip as zip
except ImportError:  # will be 3.x series
  pass
from onnx import defs
from onnx import numpy_helper
from onnx.backend.base import Backend
from onnx.backend.base import namedtupledict
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
from onnx.helper import make_opsetid
import tensorflow as tf
import numpy as np
from onnx_tf.backend_rep import TensorflowRep
from onnx_tf.common import data_type
from onnx_tf.common import get_unique_suffix
from onnx_tf.common import supports_device as common_supports_device
from onnx_tf.common.handler_helper import get_all_backend_handlers
from onnx_tf.pb_wrapper import OnnxNode
from onnx_tf.backend_tf_module import BackendTFModule, TFModule
import onnx_tf.common as common
# Name of the boolean placeholder injected into the graph (see
# _onnx_graph_to_tensorflow_rep) that toggles training-specific behavior.
training_flag_name = "_onnx_tf_internal_is_training"
class TensorflowBackend(Backend):
  """Tensorflow Backend for ONNX.

  Implements the ONNX backend API (prepare, run_node, ...) by converting
  ONNX graphs into Tensorflow modules via the registered backend handlers.
  """
@classmethod
def prepare(cls,
model,
device='CPU',
strict=True,
logging_level='INFO',
auto_cast=False,
**kwargs):
"""Prepare an ONNX model for Tensorflow Backend.
This function converts an ONNX model to an internel representation
of the computational graph called TensorflowRep and returns
the converted representation.
:param model: The ONNX model to be converted.
:param device: The device to execute this model on. It can be either CPU (default) or CUDA.
:param strict: Whether to enforce semantic equivalence between the original model
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
Changing to False is strongly discouraged.
Currently, the strict flag only affects the behavior of MaxPool and AveragePool ops.
:param logging_level: The logging level, default is INFO. Change it to DEBUG
to see more conversion details or to WARNING to see less
:param auto_cast: Whether to auto cast data types that might lose precision for the tensors
with types not natively supported by Tensorflow, default is False
:returns: A TensorflowRep class object representing the ONNX model
"""
super(TensorflowBackend, cls).prepare(model, device, **kwargs)
common.logger.setLevel(logging_level)
common.logger.handlers[0].setLevel(logging_level)
common.sys_config.auto_cast = auto_cast
common.sys_config.device = device
return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs)
@classmethod
def onnx_model_to_tensorflow_rep(cls, model, strict, **kwargs):
""" Convert ONNX model to TensorflowRep.
:param model: ONNX ModelProto object.
:param strict: whether to enforce semantic equivalence between the original model
and the converted tensorflow model.
:return: TensorflowRep object.
"""
# Models with IR_VERSION less than 3 does not have opset_import set.
# We default to minimum opset, this behavior is consistent with
# onnx checker.
# c.f. https://github.com/onnx/onnx/blob/427ac0c1b792363d373e3d7e4eef97fa46458420/onnx/checker.cc#L478
if model.ir_version < 3:
opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
else:
opset_import = model.opset_import
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict,
**kwargs)
@classmethod
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
""" Convert ONNX graph to TensorflowRep.
:param graph_def: ONNX GraphProto object.
:param opset: ONNX OperatorSetIdProto list.
:param strict: whether to enforce semantic equivalence between the original model
and the converted tensorflow model.
:kwargs: additional arguements to generate tensor_dict for model debugging
:return: TensorflowRep object.
"""
# To generate tensor_dict or not, default is False
gen_tensor_dict = kwargs[
'gen_tensor_dict'] if 'gen_tensor_dict' in kwargs else False
# User provided input tensors, in the case the model inputs have unknown shapes
input_tensor_dict = kwargs[
'input_tensor_dict'] if 'input_tensor_dict' in kwargs else dict()
training_mode = kwargs[
'training_mode'] if 'training_mode' in kwargs else False
handlers = cls._get_handlers(opset)
# initializer: TensorProtos representing the values to initialize
# a given tensor.
# initialized: A list of names of the initialized tensors.
if graph_def.initializer:
initialized = {init.name for init in graph_def.initializer}
else:
initialized = set()
input_dict = dict()
module = BackendTFModule(handlers, opset, strict, graph_def, cls)
signatures = dict()
if training_mode:
tf_rep_graph = kwargs['graph'] if 'graph' in kwargs else tf.Graph()
else:
tf_rep_graph = tf.Graph()
with tf_rep_graph.as_default():
for value_info in graph_def.input:
if value_info.name in initialized or not value_info.type.HasField(
'tensor_type'):
continue
shape = list(
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
for d in value_info.type.tensor_type.shape.dim)
value_info_name = value_info.name.replace(
":", "_tf_") + "_" + get_unique_suffix(
) if ":" in value_info.name else value_info.name
tf_spec = tf.TensorSpec(
shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type),
value_info_name)
signatures[value_info.name] = tf_spec
if gen_tensor_dict or training_mode:
x = tf.compat.v1.placeholder(
data_type.onnx2tf(value_info.type.tensor_type.elem_type),
name=value_info_name,
shape=shape
) if value_info.name not in input_tensor_dict else input_tensor_dict[
value_info.name]
input_dict[value_info.name] = x
if gen_tensor_dict or training_mode:
input_dict_items = cls._onnx_initializer_to_input_dict_items(
graph_def.initializer, training_mode=True)
tensor_dict = dict(input_dict)
tensor_dict.update(input_dict_items)
tensor_dict[training_flag_name] = tf.compat.v1.placeholder_with_default(
False, shape=[])
for node in graph_def.node:
onnx_node = OnnxNode(node)
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
tensor_dict,
handlers,
opset=opset,
strict=strict)
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
tensor_dict.update(curr_node_output_map)
tf_rep = TensorflowRep()
tf_rep.inputs = [
value_info.name
for value_info in graph_def.input
if value_info.name not in initialized
]
tf_rep.outputs = [value_info.name for value_info in graph_def.output]
module.outputs = tf_rep.outputs
tf_rep.tf_module = module
tf_rep.signatures = signatures
if gen_tensor_dict or training_mode:
tf_rep.tensor_dict = tensor_dict
if training_mode:
tf_rep.graph = tf_rep_graph
tf_rep.onnx_op_list = cls._get_onnx_op_list(graph_def)
return tf_rep
@classmethod
def _get_onnx_op_list(cls, graph_def):
""" Get ONNX operator counts of the model.
:param graph_def: ONNX GraphProto object.
:return: Dictionary of all operators counts in the model.
"""
def get_onnx_op_from_graph_and_subgraph(graph, op_list):
for node in graph.node:
op_list[node.op_type] = 1 if node.op_type not in op_list.keys(
) else op_list[node.op_type] + 1
if node.op_type in ['Loop', 'Scan']:
onnx_node = OnnxNode(node)
body = onnx_node.attrs["body"]
op_list = get_onnx_op_from_graph_and_subgraph(body, op_list)
elif node.op_type == 'If':
onnx_node = OnnxNode(node)
then_branch = onnx_node.attrs['then_branch']
op_list = get_onnx_op_from_graph_and_subgraph(then_branch, op_list)
else_branch = onnx_node.attrs['else_branch']
op_list = get_onnx_op_from_graph_and_subgraph(else_branch, op_list)
return op_list
op_list = get_onnx_op_from_graph_and_subgraph(graph_def, dict())
sorted_op_list = dict()
for key in sorted(op_list):
sorted_op_list[key] = op_list[key]
return sorted_op_list
@classmethod
def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
""" Run ONNX node.
:param node: ONNX NodeProto object.
:param inputs: Inputs.
:param device: Device run on.
:param outputs_info: None.
:param kwargs: Other args.
:return: Outputs.
"""
super(TensorflowBackend, cls).run_node(node, inputs, device)
common.sys_config.device = device
node = OnnxNode(node)
input_tensors = []
for i in inputs:
if i is None:
input_tensors.append(i)
else:
input_tensors.append(tf.constant(i))
if isinstance(inputs, dict):
feed_dict_raw = inputs
else:
assert len(node.inputs) == len(inputs)
feed_dict_raw = dict(zip(node.inputs, inputs))
# TODO: is constant the best way for feeding inputs?
input_dict = {}
for k, v in feed_dict_raw.items():
if isinstance(v, list):
list_input = []
for x in v:
if x is None:
list_input.append(x)
else:
list_input.append(tf.constant(x))
input_dict[k] = list_input
elif v is None: # keep None for empty optional data
input_dict[k] = v
else:
input_dict[k] = tf.constant(v)
module = TFModule(node, cls)
output_vals = module(**input_dict)
output_vals = [
val.numpy() if isinstance(val, tf.Tensor) else val
for val in output_vals
]
return namedtupledict('Outputs', node.outputs)(*output_vals)
@classmethod
def _onnx_initializer_to_input_dict_items(cls,
initializer,
training_mode=False):
""" Convert ONNX graph initializer to input dict items.
:param initializer: ONNX graph initializer, list of TensorProto.
:return: List of input dict items.
"""
def tensor2list(onnx_tensor):
# Use the onnx.numpy_helper because the data may be raw
return numpy_helper.to_array(onnx_tensor).flatten().tolist()
def validate_initializer_name(name):
# Prepend a unique suffix if leading character is "_"
name = get_unique_suffix() + name if name[0] == "_" else name
# Replace ":" with "_tf_" and append a unique suffix for
# traceability
return name.replace(
":", "_tf_") + "_" + get_unique_suffix() if ":" in name else name
if training_mode:
tensor_dict = [
(init.name,
tf.Variable(np.array(tensor2list(init)).reshape(init.dims),
shape=init.dims,
dtype=data_type.onnx2tf(init.data_type),
name=validate_initializer_name(init.name)))
for init in initializer
]
else:
tensor_dict = [(init.name,
tf.constant(tensor2list(init),
shape=init.dims,
dtype=data_type.onnx2tf(init.data_type),
name=validate_initializer_name(init.name)))
for init in initializer]
return tensor_dict
@classmethod
def _onnx_node_to_tensorflow_op(cls,
node,
tensor_dict,
handlers=None,
opset=None,
strict=True):
"""
Convert onnx node to tensorflow op.
Args:
node: Onnx node object.
tensor_dict: Tensor dict of graph.
opset: Opset version of the operator set. Default 0 means using latest version.
strict: whether to enforce semantic equivalence between the original model
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
Changing to False is strongly discouraged.
Returns:
Tensorflow op
"""
handlers = handlers or cls._get_handlers(opset)
if handlers:
handler = handlers[node.domain].get(
node.op_type, None) if node.domain in handlers else None
if handler:
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(
node.op_type))
@classmethod
def _get_handlers(cls, opset):
""" Get all backend handlers with opset.
:param opset: ONNX OperatorSetIdProto list.
:return: All backend handlers.
"""
opset = opset or [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
opset_dict = dict([(o.domain, o.version) for o in opset])
return get_all_backend_handlers(opset_dict)
  @classmethod
  def supports_device(cls, device):
    """Return whether the backend supports the given device name (delegates to onnx_tf.common.supports_device)."""
    return common_supports_device(device)
  @classmethod
  def onnx_graph_to_tensorflow_ops(cls,
                                   subgraph,
                                   tensor_dict,
                                   opset=None,
                                   strict=True):
    """
    Converts ONNX graph to Tensorflow operations
    Args:
      subgraph: the ONNX graph to be converted.
      tensor_dict: tensor dict of the subgraph; mutated in place as each
        node's outputs are added.
      opset: opset version of the operator set.
      strict: whether to enforce semantic equivalence between the
        original model and the converted tensorflow model,
        defaults to True (yes, enforce semantic equivalence).
    Returns:
      the (mutated) tensor dict, now also mapping every node output name
      in the subgraph to its Tensorflow tensor.
    """
    for node in subgraph.node:
      onnx_node = OnnxNode(node)
      # Convert one node and merge its named outputs into the dict so that
      # subsequent nodes can look up their inputs.
      output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
                                                   tensor_dict,
                                                   opset=opset,
                                                   strict=strict)
      curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
      tensor_dict.update(curr_node_output_map)
    return tensor_dict
@classmethod
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs):
"""
Converts ONNX graph to TensorflowRep
Args:
graph_def: the ONNX graph to be converted
strict: whether to enforce semantic equivalence between the
original model and the converted tensorflow model,
defaults to True (yes, enforce semantic equivalence).
Returns:
TensorflowRep object.
"""
# get the opset of the installed ONNX
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict, **kwargs)
# Module-level re-exports so callers can use the standard ONNX backend API
# (onnx_tf.backend.prepare(...), run_node(...), etc.) without referencing
# the TensorflowBackend class directly. run_model is presumably inherited
# from the Backend base class — it is not defined in this file.
prepare = TensorflowBackend.prepare
run_node = TensorflowBackend.run_node
run_model = TensorflowBackend.run_model
supports_device = TensorflowBackend.supports_device
onnx_graph_to_tensorflow_ops = TensorflowBackend.onnx_graph_to_tensorflow_ops
onnx_graph_to_tensorflow_rep = TensorflowBackend.onnx_graph_to_tensorflow_rep