-
Notifications
You must be signed in to change notification settings - Fork 296
/
attr_converter.py
89 lines (79 loc) · 2.44 KB
/
attr_converter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from onnx_tf.common import IS_PYTHON3
from onnx_tf.common.legacy import legacy_opset_pre_ver
def convert_tf(attr):
    """Convert a TensorFlow ``AttrValue`` into a Python object.

    Public entry point; the conversion itself is done by
    ``__convert_tf_attr_value``, whose result is returned unchanged.
    """
    python_value = __convert_tf_attr_value(attr)
    return python_value
def convert_onnx(attr):
    """Convert an ONNX ``AttributeProto`` into a Python object.

    Public entry point; the conversion itself is done by
    ``__convert_onnx_attribute_proto``, whose result is returned unchanged.
    """
    python_value = __convert_onnx_attribute_proto(attr)
    return python_value
def __convert_tf_attr_value(attr):
    """Convert a TensorFlow ``AttrValue`` object to a Python object.

    A ``list`` value is delegated to ``__convert_tf_list_value``. Every other
    supported field is returned exactly as stored on ``attr``: no decoding
    of ``s``, and ``shape`` / ``tensor`` are returned as the objects held in
    those fields.

    Args:
        attr: a TensorFlow ``AttrValue`` (any object exposing ``HasField``
            and the field attributes probed below).

    Returns:
        The value of the first set field, probed in the order ``list``,
        ``s``, ``i``, ``f``, ``b``, ``type``, ``shape``, ``tensor``.

    Raises:
        ValueError: if none of those fields is set.
    """
    if attr.HasField('list'):
        return __convert_tf_list_value(attr.list)
    for field in ('s', 'i', 'f', 'b', 'type', 'shape', 'tensor'):
        if attr.HasField(field):
            # Return the value of the field that is actually set, so that
            # e.g. a shape attribute yields ``attr.shape`` and not
            # ``attr.type``.
            return getattr(attr, field)
    raise ValueError("Unsupported Tensorflow attribute: {}".format(attr))
def __convert_tf_list_value(list_value):
    """Convert a TensorFlow ``AttrValue.ListValue`` object to a Python object.

    Returns the first non-empty repeated field, probed in the order ``s``,
    ``i``, ``f``, ``b``, ``tensor``, ``type``, ``shape``, ``func``. The field
    container itself is returned; it is not copied into a ``list``.

    Args:
        list_value: the ``list`` member of a TensorFlow ``AttrValue``.

    Returns:
        The first non-empty field, or an empty ``list`` when no field holds
        any element.
    """
    for field in ('s', 'i', 'f', 'b', 'tensor', 'type', 'shape', 'func'):
        values = getattr(list_value, field)
        if values:
            return values
    # Every field probed above is empty, i.e. the attribute is a list with
    # zero elements. That is a value, not an unsupported attribute, so
    # return an empty list instead of raising.
    return []
def __convert_onnx_attribute_proto(attr_proto):
    """Convert an ONNX ``AttributeProto`` into a Python object for its type.

    Scalars ``f`` and ``i`` are returned directly. ``s`` and each entry of
    ``strings`` are decoded as UTF-8 when ``IS_PYTHON3`` is true. The
    repeated fields ``floats``, ``ints``, ``strings``, ``tensors`` and
    ``graphs`` are returned as ``list`` copies.

    NB: tensor, graph, sparse-tensor and type-proto attributes are returned
    as the straight protos (not converted).

    Args:
        attr_proto: an ONNX ``AttributeProto``.

    Returns:
        The Python value of the first populated field.

    Raises:
        ValueError: if no supported field is populated.
    """
    if attr_proto.HasField('f'):
        return attr_proto.f
    if attr_proto.HasField('i'):
        return attr_proto.i
    if attr_proto.HasField('s'):
        return str(attr_proto.s, 'utf-8') if IS_PYTHON3 else attr_proto.s
    if attr_proto.HasField('t'):
        return attr_proto.t  # this is a proto!
    if attr_proto.HasField('g'):
        return attr_proto.g
    if attr_proto.floats:
        return list(attr_proto.floats)
    if attr_proto.ints:
        return list(attr_proto.ints)
    if attr_proto.strings:
        if IS_PYTHON3:
            return [str(item, 'utf-8') for item in attr_proto.strings]
        return list(attr_proto.strings)
    # Lists of tensors / graphs; like ``t`` and ``g``, the elements are
    # returned as protos.
    if attr_proto.tensors:
        return list(attr_proto.tensors)
    if attr_proto.graphs:
        return list(attr_proto.graphs)
    if attr_proto.HasField('sparse_tensor'):
        return attr_proto.sparse_tensor
    # NOTE(review): presumably guarded because the ``tp`` field is absent
    # from ONNX releases older than opset 15, where HasField would fail.
    # Confirm against legacy_opset_pre_ver.
    if not legacy_opset_pre_ver(15) and attr_proto.HasField('tp'):
        return attr_proto.tp
    raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))