/
common.py
349 lines (269 loc) · 11.7 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
""" test common utilities."""
import argparse
import os
import sys
import unittest
from collections import defaultdict
from distutils.version import LooseVersion
from parameterized import parameterized
import numpy as np
import tensorflow as tf
from tf2onnx import constants, logging, utils
# pylint: disable=import-outside-toplevel
# Names exported by `from common import *`.
# NOTE(review): is_tf_gpu, check_lstm_count and check_gru_count are defined in
# this module but are not listed here, so a star-import will not pick them up;
# confirm whether that omission is intentional.
__all__ = [
"TestConfig",
"get_test_config",
"unittest_main",
"check_onnxruntime_backend",
"check_tf_min_version",
"check_tf_max_version",
"skip_tf_versions",
"skip_tf_cpu",
"check_onnxruntime_min_version",
"check_opset_min_version",
"check_opset_max_version",
"check_target",
"skip_caffe2_backend",
"skip_onnxruntime_backend",
"skip_opset",
"check_onnxruntime_incompatibility",
"validate_const_node",
"group_nodes_by_type",
"test_ms_domain",
"check_node_domain",
"check_op_count"
]
# pylint: disable=missing-docstring
class TestConfig(object):
    """ Settings that drive the test run.

    Defaults come from the environment variables TF2ONNX_TEST_OPSET,
    TF2ONNX_TEST_TARGET and TF2ONNX_TEST_BACKEND; `load()` additionally lets
    command-line arguments override them when not launched by pytest.
    """

    def __init__(self):
        env = os.environ
        self.platform = sys.platform
        self.tf_version = utils.get_tf_version()
        self.opset = int(env.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET))
        default_targets = ",".join(constants.DEFAULT_TARGET)
        self.target = env.get("TF2ONNX_TEST_TARGET", default_targets).split(',')
        self.backend = env.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
        # must come after self.backend is set: the lookup depends on it
        self.backend_version = self._get_backend_version()
        self.log_level = logging.WARNING
        self.temp_dir = utils.get_temp_directory()

    @property
    def is_mac(self):
        """ True when sys.platform was "darwin" at construction time. """
        return "darwin" == self.platform

    @property
    def is_onnxruntime_backend(self):
        """ True when the selected backend is onnxruntime. """
        return "onnxruntime" == self.backend

    @property
    def is_caffe2_backend(self):
        """ True when the selected backend is caffe2. """
        return "caffe2" == self.backend

    @property
    def is_debug_mode(self):
        """ Current debug flag as reported by utils.is_debug_mode(). """
        return utils.is_debug_mode()

    def _get_backend_version(self):
        """ Return the backend version as a LooseVersion, or a falsy value if unknown. """
        if self.backend == "onnxruntime":
            # imported lazily so onnxruntime is only required when it is the backend
            import onnxruntime as ort
            raw_version = ort.__version__
        else:
            # TODO: get caffe2 version
            raw_version = None
        return LooseVersion(raw_version) if raw_version else raw_version

    def __str__(self):
        summary = ["TestConfig:"]
        summary.append("platform={}".format(self.platform))
        summary.append("tf_version={}".format(self.tf_version))
        summary.append("opset={}".format(self.opset))
        summary.append("target={}".format(self.target))
        summary.append("backend={}".format(self.backend))
        summary.append("backend_version={}".format(self.backend_version))
        summary.append("is_debug_mode={}".format(self.is_debug_mode))
        summary.append("temp_dir={}".format(self.temp_dir))
        return "\n\t".join(summary)

    @staticmethod
    def load():
        """ Build a TestConfig, applying command-line overrides unless run by pytest.

        When arguments are parsed, sys.argv[1:] is replaced with the remaining
        positional arguments so that they can be handed on to unittest.
        """
        config = TestConfig()

        # if launched by pytest, the console arguments are not ours to parse
        if "pytest" in sys.argv[0]:
            return config

        parser = argparse.ArgumentParser()
        parser.add_argument("--backend", default=config.backend,
                            choices=["caffe2", "onnxruntime"],
                            help="backend to test against")
        parser.add_argument("--opset", type=int, default=config.opset, help="opset to test against")
        parser.add_argument("--target", default=",".join(config.target), choices=constants.POSSIBLE_TARGETS,
                            help="target platform")
        parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
        parser.add_argument("--debug", help="output debugging information", action="store_true")
        parser.add_argument("--temp_dir", help="temp dir")
        parser.add_argument("unittest_args", nargs='*')

        parsed = parser.parse_args()
        if parsed.debug:
            utils.set_debug_mode(True)

        config.backend = parsed.backend
        config.opset = parsed.opset
        config.target = parsed.target.split(',')
        config.log_level = logging.get_verbosity_level(parsed.verbose, config.log_level)
        if parsed.temp_dir:
            config.temp_dir = parsed.temp_dir

        # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
        sys.argv[1:] = parsed.unittest_args
        return config
# Module-level test configuration, created once when this module is imported.
# The config must be loaded BEFORE main is executed when launched from a script;
# otherwise it would be too late for the test filters (the check_*/skip_*
# helpers below, which read it via get_test_config()) to take effect.
_config = TestConfig.load()
def get_test_config():
    """ Return the module-level test configuration object. """
    # read-only access to the module global; no `global` declaration is required
    return _config
def unittest_main():
    """ Set up logging from the test config, log the config, then run unittest.main(). """
    test_config = get_test_config()
    logging.basicConfig(level=test_config.log_level)
    # log the configuration inside an INFO-level scope
    # (presumably so it is emitted even when log_level is WARNING; confirm
    # against logging.set_scope_level)
    with logging.set_scope_level(logging.INFO) as scoped_logger:
        scoped_logger.info(test_config)
    unittest.main()
def _append_message(reason, message):
    """ Return reason, suffixed with ": <message>" when message is non-empty. """
    return reason + ": " + message if message else reason
def check_tf_max_version(max_accepted_version, message=""):
    """ Skip if tf_version > max_accepted_version """
    test_config = get_test_config()
    skip_reason = _append_message("conversion requires tf <= {}".format(max_accepted_version), message)
    tf_too_new = test_config.tf_version > LooseVersion(max_accepted_version)
    return unittest.skipIf(tf_too_new, skip_reason)
def check_tf_min_version(min_required_version, message=""):
    """ Skip if tf_version < min_required_version """
    test_config = get_test_config()
    skip_reason = _append_message("conversion requires tf >= {}".format(min_required_version), message)
    tf_too_old = test_config.tf_version < LooseVersion(min_required_version)
    return unittest.skipIf(tf_too_old, skip_reason)
def skip_tf_versions(excluded_versions, message=""):
    """ Skip if tf_version SEMANTICALLY matches any of excluded_versions.

    An excluded version matches when every one of its dot-separated tokens equals
    the token at the same position of the running tf version, e.g. "1.13"
    matches "1.13.1" but neither "1.14.0" nor "1.1.3".

    Args:
        excluded_versions: an iterable of version strings, or a single version string.
        message: optional text appended to the skip reason.

    Returns:
        The decorator produced by unittest.skipIf.
    """
    # a bare string would otherwise be iterated character by character
    if isinstance(excluded_versions, str):
        excluded_versions = [excluded_versions]

    config = get_test_config()
    reason = _append_message("conversion excludes tf {}".format(excluded_versions), message)
    current_tokens = str(config.tf_version).split('.')

    def _is_excluded(excluded_version):
        exclude_tokens = excluded_version.split('.')
        # Compare against the same-length prefix of the current version. Slicing
        # (instead of indexing) means an excluded version with more tokens than
        # the current one simply does not match rather than raising IndexError.
        return current_tokens[:len(exclude_tokens)] == exclude_tokens

    # Skip only when some excluded version matches on ALL of its tokens;
    # previously the flag was set without regard to the token comparison.
    condition = any(_is_excluded(version) for version in excluded_versions)
    return unittest.skipIf(condition, reason)
def is_tf_gpu():
    """ Return whatever tf.test.is_gpu_available() reports. """
    gpu_available = tf.test.is_gpu_available()
    return gpu_available
def skip_tf_cpu(message=""):
    """ Skip if is_tf_gpu() reports no GPU; message is used verbatim as the reason. """
    return unittest.skipIf(not is_tf_gpu(), message)
def check_opset_min_version(min_required_version, message=""):
    """ Skip if opset < min_required_version """
    test_config = get_test_config()
    skip_reason = _append_message("conversion requires opset >= {}".format(min_required_version), message)
    opset_too_old = test_config.opset < min_required_version
    return unittest.skipIf(opset_too_old, skip_reason)
def check_opset_max_version(max_accepted_version, message=""):
    """ Skip if opset > max_accepted_version """
    test_config = get_test_config()
    skip_reason = _append_message("conversion requires opset <= {}".format(max_accepted_version), message)
    opset_too_new = test_config.opset > max_accepted_version
    return unittest.skipIf(opset_too_new, skip_reason)
def skip_opset(opset_v, message=""):
    """ Skip if opset == opset_v """
    test_config = get_test_config()
    skip_reason = _append_message("conversion requires opset != {}".format(opset_v), message)
    is_excluded_opset = test_config.opset == opset_v
    return unittest.skipIf(is_excluded_opset, skip_reason)
def check_target(required_target, message=""):
    """ Skip if required_target is NOT among the configured targets """
    test_config = get_test_config()
    skip_reason = _append_message("conversion requires target {} specified".format(required_target), message)
    target_missing = required_target not in test_config.target
    return unittest.skipIf(target_missing, skip_reason)
def skip_onnxruntime_backend(message=""):
    """ Skip if backend is onnxruntime """
    skip_reason = _append_message("not supported by onnxruntime", message)
    return unittest.skipIf(get_test_config().is_onnxruntime_backend, skip_reason)
def check_onnxruntime_backend(message=""):
    """ Skip if backend is NOT onnxruntime """
    skip_reason = _append_message("only supported by onnxruntime", message)
    other_backend = not get_test_config().is_onnxruntime_backend
    return unittest.skipIf(other_backend, skip_reason)
def check_onnxruntime_min_version(min_required_version, message=""):
    """ Skip if backend is onnxruntime AND its version < min_required_version """
    test_config = get_test_config()
    skip_reason = _append_message("conversion requires onnxruntime >= {}".format(min_required_version), message)
    if test_config.is_onnxruntime_backend:
        # the version is only compared when onnxruntime is the active backend
        runtime_too_old = test_config.backend_version < LooseVersion(min_required_version)
    else:
        runtime_too_old = False
    return unittest.skipIf(runtime_too_old, skip_reason)
def skip_caffe2_backend(message=""):
    """ Skip if backend is caffe2 """
    skip_reason = _append_message("not supported by caffe2", message)
    return unittest.skipIf(get_test_config().is_caffe2_backend, skip_reason)
def check_onnxruntime_incompatibility(op):
    """ Skip if backend is onnxruntime AND op is NOT supported in current opset """
    test_config = get_test_config()
    never_skip = unittest.skipIf(False, None)
    if not test_config.is_onnxruntime_backend:
        return never_skip

    # op type -> first opset treated as supported; the trailing comments name
    # the older op versions that are treated as unsupported
    support_since = {
        "Abs": 6,  # Abs-1
        "Add": 7,  # Add-1, Add-6
        "AveragePool": 7,  # AveragePool-1
        "Div": 7,  # Div-1, Div-6
        "Elu": 6,  # Elu-1
        "Equal": 7,  # Equal-1
        "Exp": 6,  # Exp-1
        "Greater": 7,  # Greater-1
        "Less": 7,  # Less-1
        "Log": 6,  # Log-1
        "Max": 6,  # Max-1
        "Min": 6,  # Min-1
        "Mul": 7,  # Mul-1, Mul-6
        "Neg": 6,  # Neg-1
        "Pow": 7,  # Pow-1
        "Reciprocal": 6,  # Reciprocal-1
        "Relu": 6,  # Relu-1
        "Sqrt": 6,  # Sqrt-1
        "Sub": 7,  # Sub-1, Sub-6
        "Tanh": 6,  # Tanh-1
    }

    first_supported_opset = support_since.get(op)
    # ops missing from the table are not restricted
    if first_supported_opset is None:
        return never_skip
    if test_config.opset >= first_supported_opset:
        return never_skip

    skip_reason = "{} is not supported by onnxruntime before opset {}".format(op, first_supported_opset)
    return unittest.skipIf(True, skip_reason)
def validate_const_node(node, expected_val):
    """ Check a const node's tensor value against expected_val.

    Returns False when node is not a const. Otherwise asserts (via
    np.testing.assert_allclose) that the node's tensor value is close to
    expected_val and returns True; a mismatch raises AssertionError.
    """
    if not node.is_const():
        return False
    actual_val = node.get_tensor_value()
    np.testing.assert_allclose(expected_val, actual_val)
    return True
def group_nodes_by_type(graph):
    """ Group the nodes of graph, and of any nested body graphs, by node type.

    Returns a defaultdict(list) mapping node.type to the list of nodes of that
    type. For a node that has body graphs, the nodes found in those body graphs
    are added before the node itself.
    """
    grouped = defaultdict(list)
    for node in graph.get_nodes():
        body_graphs = node.get_body_graphs()
        if body_graphs:
            for body_graph in body_graphs.values():
                nested = group_nodes_by_type(body_graph)
                for op_type, nested_nodes in nested.items():
                    grouped[op_type].extend(nested_nodes)
        grouped[node.type].append(node)
    return grouped
def check_op_count(graph, op_type, expected_count):
    """ Return True if graph (body graphs included) has exactly expected_count op_type nodes. """
    nodes_by_type = group_nodes_by_type(graph)
    matching_nodes = nodes_by_type[op_type]
    return len(matching_nodes) == expected_count
def check_lstm_count(graph, expected_count):
    """ Return True if graph contains exactly expected_count "LSTM" nodes. """
    return check_op_count(graph, op_type="LSTM", expected_count=expected_count)
def check_gru_count(graph, expected_count):
    """ Return True if graph contains exactly expected_count "GRU" nodes. """
    return check_op_count(graph, op_type="GRU", expected_count=expected_count)
# Highest ms-domain opset version covered when test_ms_domain() is given no versions.
_MAX_MS_OPSET_VERSION = 1


def test_ms_domain(versions=None):
    """ Parameterize test case to apply ms opset(s) as extra_opset.

    Each generated case receives one opset id built with utils.make_opsetid for
    constants.MICROSOFT_DOMAIN, and its name is suffixed with that opset's
    version. When versions is None, versions 1.._MAX_MS_OPSET_VERSION are used.
    """

    # NOTE(review): the onnxruntime skip decorator is applied to the naming
    # callback, not to the test function itself; confirm this is intended.
    @check_onnxruntime_backend()
    def _custom_name_func(testcase_func, param_num, param):
        del param_num
        opset_id = param.args[0]
        return "%s_%s" % (testcase_func.__name__, opset_id.version)

    # Test all opset versions in ms domain if versions is not specified
    if versions is None:
        versions = list(range(1, _MAX_MS_OPSET_VERSION + 1))

    # one single-argument parameter list per requested version
    opsets = [[utils.make_opsetid(constants.MICROSOFT_DOMAIN, ms_version)] for ms_version in versions]
    return parameterized.expand(opsets, testcase_func_name=_custom_name_func)
def check_node_domain(node, domain):
    """ Return True if node belongs to domain; None or "" stands for the onnx domain. """
    if domain:
        return node.domain == domain
    # None or empty string means onnx domain, so the node's domain must be falsy too
    return not node.domain