-
Notifications
You must be signed in to change notification settings - Fork 243
/
attention.py
119 lines (104 loc) · 4.32 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import onnx
from neural_compressor.adaptor.ox_utils.operators.ops import op_registry, Operator, qop_registry, QOperator
from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg, ms_domain, find_by_name
@op_registry(op_types="Attention")
class AttentionOperator(Operator):
    """Quantization handler for float ``Attention`` nodes.

    ``quantize`` asks the quantizer to quantize the activation (input 0) and
    weight (input 1) and tags the node with a ``_quant`` suffix.
    ``convert`` then fuses the resulting DequantizeLinear /
    DynamicQuantizeLinear parents into a single ``QAttention`` node.
    """

    def __init__(self, onnx_quantizer, onnx_node):
        super(AttentionOperator, self).__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        """Request quantization of the activation and weight inputs.

        The bias (input 2) is deliberately kept in float. ``convert`` only
        fuses the first two quantized parents and forwards ``node.input[2]``
        unchanged as the QAttention bias, so quantizing the bias would only
        add a lossy quantize/dequantize round trip that is never fused.
        """
        node = self.node
        self.quantizer.quantize_inputs(node, [0, 1])
        # The suffix marks the node as quantized; convert_check relies on it.
        node.name = node.name + "_quant"

    def convert_check(self, convert_format):
        """Return True if this node was tagged by ``quantize`` and can be converted.

        Raises:
            AssertionError: if ``convert_format`` is not 'dynamic' or 'static'.
        """
        node = self.node
        assert convert_format in ['dynamic', 'static'], \
            "convert format for {} should be in ['dynamic', 'static']".format(node.op_type)
        return node.name.endswith('_quant')

    def convert(self, convert_format):
        """Replace this Attention node and its dequantize parents with a QAttention node.

        The first two parents must be DequantizeLinear (static) or
        DynamicQuantizeLinear (dynamic) nodes feeding the activation and the
        weight. Their quantized tensor, scale and zero point become QAttention
        inputs. DequantizeLinear parents are scheduled for removal;
        DynamicQuantizeLinear parents are kept because their outputs feed the
        new node directly. If fewer than two quantized parents are found, the
        node is left unchanged rather than emitting a QAttention with
        misaligned inputs.
        """
        node = self.node
        parents = self.quantizer.model.get_parents(node)
        quantized_name = []
        scale = []
        zp = []
        dequant_parents = []
        for parent in parents[:2]:
            if parent.op_type == 'DequantizeLinear':
                quantized_name.append(parent.input[0])
                scale.append(parent.input[1])
                zp.append(parent.input[2])
                dequant_parents.append(parent)
            elif parent.op_type == 'DynamicQuantizeLinear':
                quantized_name.append(parent.output[0])
                scale.append(parent.output[1])
                zp.append(parent.output[2])

        # Every QAttention slot below is positional; with a missing quantized
        # parent the scales/zero points would shift into the wrong slots.
        if len(quantized_name) != 2:
            return
        self.quantizer.remove_nodes.extend(dequant_parents)

        # QAttention input order: input, weight, bias, input_scale,
        # weight_scale, mask_index, input_zero_point, weight_zero_point, [past].
        mask_index = node.input[3] if len(node.input) > 3 else ""
        inputs = [*quantized_name, node.input[2], *scale, mask_index, *zp]
        if len(node.input) > 4:
            # Attention's input 4 (past) maps to QAttention's trailing slot.
            inputs.append(node.input[4])

        kwargs = {}
        for attribute in node.attribute: # pragma: no cover
            kwargs.update(attribute_to_kwarg(attribute))
        kwargs["domain"] = ms_domain
        qattention_node = onnx.helper.make_node("QAttention", inputs, node.output,
                                                node.name, **kwargs)
        self.quantizer.new_nodes.append(qattention_node)
        self.quantizer.remove_nodes.append(node)
@qop_registry(op_types="QAttention")
class QAttentionOperator(QOperator):
    """Handler that rewrites a ``QAttention`` node as DequantizeLinear + ``Attention``."""

    def __init__(self, onnx_node, children, initializers):
        super().__init__(onnx_node, children, initializers)

    def _optional_input(self, index):
        """Return ``self.node.input[index]``, or "" if the node has fewer inputs."""
        node_inputs = self.node.input
        return node_inputs[index] if len(node_inputs) > index else ""

    def _dequantize(self, suffix, data, scale, zero_point):
        """Build a DequantizeLinear node (and output) named ``<node.name><suffix>``.

        An empty ``zero_point`` is omitted because it is an optional
        DequantizeLinear input.
        """
        name = self.node.name + suffix
        dq_inputs = [data, scale]
        if zero_point:
            dq_inputs.append(zero_point)
        return onnx.helper.make_node('DequantizeLinear', dq_inputs, [name], name)

    def convert(self):
        """Convert this QAttention node into two DequantizeLinear nodes plus Attention.

        Inputs 0/1 (quantized activation/weight) are dequantized with scales
        3/4 and zero points 6/7. Bias (2), mask_index (5) and past (8) are
        forwarded to the float Attention node, which keeps the original
        outputs.

        Returns:
            tuple: ``(converted, add_nodes, inits)``. ``converted`` is False and
            ``add_nodes`` is empty when the activation scale ``node.input[3]``
            is not an initializer. ``inits`` is always empty.
        """
        node = self.node
        add_nodes = []
        inits = []
        # Only convert when the activation scale is a constant initializer.
        # NOTE(review): presumably this skips dynamically-quantized graphs,
        # whose scale is computed at runtime — confirm.
        if find_by_name(node.input[3], self.initializers) is None:
            return False, add_nodes, inits

        in_dq1 = self._dequantize('_in_dequant1', node.input[0], node.input[3],
                                  self._optional_input(6))
        in_dq2 = self._dequantize('_in_dequant2', node.input[1], node.input[4],
                                  self._optional_input(7))
        add_nodes.extend([in_dq1, in_dq2])

        # Attention input order: input, weights, bias, mask_index, [past].
        inputs = [in_dq1.output[0],
                  in_dq2.output[0],
                  node.input[2],
                  self._optional_input(5)]
        past = self._optional_input(8)
        if past:
            # Previously dropped; without it the converted graph would ignore
            # the cached past state the QAttention node consumed.
            inputs.append(past)

        kwargs = {}
        for attribute in node.attribute: # pragma: no cover
            kwargs.update(attribute_to_kwarg(attribute))
        kwargs["domain"] = ms_domain
        attention_node = onnx.helper.make_node(
            'Attention', inputs,
            node.output, node.name + '_convert', **kwargs)
        add_nodes.append(attention_node)
        return True, add_nodes, inits