Add new config items and support smooth quant (#514)
Signed-off-by: wenhuach21 <wenhua.cheng@intel.com>
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
mengniwang95 committed Feb 16, 2023
1 parent 5c00d63 commit cbb5cf5
Showing 15 changed files with 871 additions and 209 deletions.
181 changes: 141 additions & 40 deletions neural_compressor/adaptor/onnxrt.py

Large diffs are not rendered by default.

250 changes: 183 additions & 67 deletions neural_compressor/adaptor/ox_utils/calibration.py

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions neural_compressor/adaptor/ox_utils/quantizer.py
@@ -38,7 +38,6 @@
 from neural_compressor.adaptor.ox_utils.util import find_by_name, dtype_to_name
 from neural_compressor.adaptor.ox_utils.util import __producer__, __version__
 from neural_compressor.adaptor.ox_utils.util import quantize_data, dtype_mapping, support_pair, ValueInfo
-from neural_compressor import options
 from neural_compressor.model.onnx_model import ONNXModel
 from neural_compressor.adaptor.ox_utils.operators import OPERATORS

@@ -48,7 +47,9 @@ class Quantizer:
     """Quantizer class."""

     def __init__(self, model, q_config, mode, static, quantization_params,
-                 op_types_to_quantize, fallback_list=['fp32'], reduce_range=None):
+                 op_types_to_quantize, fallback_list=['fp32'], reduce_range=None,
+                 add_qdq_pair_to_weight=False, optypes_to_exclude_output_quant=[],
+                 dedicated_qdq_pair=False):
         """Initialization.

         Args:
@@ -60,10 +61,13 @@ def __init__(self, model, q_config, mode, static, quantization_params,
             op_types_to_quantize (list): optypes to quantize
             fallback_list (list, optional): fallback data type. Defaults to ['fp32'].
             reduce_range (bool, optional): use 7 bit or not. Defaults to None.
+            add_qdq_pair_to_weight (bool, optional): add a QDQ pair to each weight or not. Defaults to False.
+            optypes_to_exclude_output_quant (list, optional): optypes whose output quantization is skipped. Defaults to [].
+            dedicated_qdq_pair (bool, optional): create a dedicated QDQ pair for each consumer or not. Defaults to False.
         """
         self.model = ONNXModel(model) if not isinstance(model, ONNXModel) else model
         model = onnx.shape_inference.infer_shapes(self.model.model) if \
-            not self.model.large_size else self.model.model
+            not self.model.is_large_model else self.model.model
         self.config = q_config
         self.reduce_range = reduce_range
         self.mode = mode  # QuantizationMode.Value
@@ -96,12 +100,10 @@ def __init__(self, model, q_config, mode, static, quantization_params,
         if not self.static:
             self.op_types_to_exclude_output_quantization = op_types_to_quantize
         else:
-            self.op_types_to_exclude_output_quantization = [] \
-                if not options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin \
-                else options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin
+            self.op_types_to_exclude_output_quantization = optypes_to_exclude_output_quant

-        self.add_qdq_pair_to_weight = options.onnxrt.qdq_setting.AddQDQPairToWeight
-        self.dedicated_qdq_pair = options.onnxrt.qdq_setting.DedicatedQDQPair
+        self.add_qdq_pair_to_weight = add_qdq_pair_to_weight
+        self.dedicated_qdq_pair = dedicated_qdq_pair

     def check_opset_version(self):
         """Check opset version."""
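For reference, a hypothetical construction of the extended Quantizer showing where the three new arguments land; `model`, `q_config`, `mode`, and `quant_params` are placeholders for objects produced earlier in the ONNX Runtime adaptor flow, not values taken from this diff:

    # Hypothetical usage sketch of the new constructor arguments.
    quantizer = Quantizer(model, q_config, mode, static=True,
                          quantization_params=quant_params,
                          op_types_to_quantize=['Conv', 'MatMul'],
                          add_qdq_pair_to_weight=True,               # keep a QDQ pair on each weight
                          optypes_to_exclude_output_quant=['Relu'],  # skip output quantization for these optypes
                          dedicated_qdq_pair=False)                  # consumers may share a QDQ pair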
233 changes: 233 additions & 0 deletions neural_compressor/adaptor/ox_utils/util.py
@@ -452,3 +452,236 @@ def find_by_name(name, item_list):
         return items[0]
     else:
         return None

def get_smooth_scales_per_op(max_vals_per_channel, input_tensors_2_weights,
                             input_tensors_2_weights_nodes, alpha):
    """Get the smooth scales for weights, one scale per consumer op.

    Args:
        max_vals_per_channel: max values per input channel after calibration
        input_tensors_2_weights: a dict mapping an input tensor name to its corresponding weights
        input_tensors_2_weights_nodes: a dict mapping an input tensor name to its corresponding weight nodes
        alpha: the smoothing alpha from the SmoothQuant paper

    Returns:
        the smooth scales for weights, keyed by node name
    """
    scales = {}
    for key in input_tensors_2_weights_nodes.keys():
        nodes = input_tensors_2_weights_nodes[key]
        for index, node in enumerate(nodes):
            name = node.name
            weight = input_tensors_2_weights[key][index]
            if len(weight.shape) == 4:  # conv
                if weight.shape[1] == 1:  # depthwise conv
                    pass
                else:
                    weight = np.moveaxis(weight, 0, 1)
            weight = weight.reshape(weight.shape[0], -1)
            weight_max_per_channel = np.amax(weight, axis=-1)
            input_power = np.power(max_vals_per_channel[key], alpha)
            weight_power = np.power(weight_max_per_channel, 1 - alpha)
            scale = np.clip(input_power / weight_power, a_min=1e-5, a_max=None)
            scales[name] = scale
    return scales
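The core of the routine above is the SmoothQuant formula s = clip(max|X|^alpha / max|W|^(1 - alpha), 1e-5), computed per input channel. A self-contained toy check of that formula, with made-up calibration values:

    import numpy as np

    alpha = 0.5
    act_max = np.array([10.0, 0.5, 3.0])   # made-up per-input-channel activation maxima
    weight = np.array([[1.0, 2.0],         # made-up MatMul weight, shape ic x oc
                       [0.25, 0.5],
                       [3.0, 1.5]])
    weight_max = np.amax(weight.reshape(weight.shape[0], -1), axis=-1)  # per-ic max, as above
    scale = np.clip(np.power(act_max, alpha) / np.power(weight_max, 1 - alpha),
                    a_min=1e-5, a_max=None)
    print(scale)  # activations are divided by this (via a Mul); weights are multiplied by it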

def get_smooth_scales_per_input(max_vals_per_channel, input_tensors_2_weights, alpha):
    """Get the smooth scales for weights.

    The ops sharing the same input will share one Mul layer and one scale.
    TODO: support individual scales for each layer.

    Args:
        max_vals_per_channel: max values per input channel after calibration
        input_tensors_2_weights: a dict mapping an input tensor name to its corresponding weights
        alpha: the smoothing alpha from the SmoothQuant paper

    Returns:
        the smooth scales for weights, one scale per input tensor
    """
    scales = {}
    for key in input_tensors_2_weights.keys():
        weights = input_tensors_2_weights[key]
        weights_in_channel_max = []
        for weight in weights:  # matmul ic*oc, conv oc*ic*k*k
            if len(weight.shape) == 4:  # conv
                if weight.shape[1] == 1:  # depthwise conv
                    pass
                else:
                    weight = np.moveaxis(weight, 0, 1)
            weight = weight.reshape(weight.shape[0], -1)
            cur_max = np.amax(weight, axis=-1)
            weights_in_channel_max.append(cur_max)
        weights_stack = np.stack(weights_in_channel_max, axis=-1)
        weights_stack = np.abs(weights_stack.reshape(weights_stack.shape[0], -1))
        weights_max = np.amax(weights_stack, axis=-1)
        input_power = np.power(max_vals_per_channel[key], alpha)
        weight_power = np.power(weights_max, 1 - alpha)
        scale = np.clip(input_power / weight_power, a_min=1e-5, a_max=None)
        scales[key] = scale
    return scales
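The distinctive step here is the aggregation: every consumer contributes its per-input-channel weight maxima, and the overall maximum across consumers drives one shared scale. A toy illustration with two MatMul consumers of the same input (values made up):

    import numpy as np

    alpha = 0.5
    act_max = np.array([8.0, 1.0])             # 2 input channels
    w1 = np.array([[2.0, 0.1], [0.3, 4.0]])    # ic x oc weight of consumer 1
    w2 = np.array([[0.5, 0.5], [1.0, 1.0]])    # ic x oc weight of consumer 2
    per_ic_max = [np.amax(w, axis=-1) for w in (w1, w2)]
    stacked = np.abs(np.stack(per_ic_max, axis=-1))   # ic x num_consumers
    weights_max = np.amax(stacked, axis=-1)           # one max per input channel
    scale = np.clip(np.power(act_max, alpha) / np.power(weights_max, 1 - alpha),
                    a_min=1e-5, a_max=None)
    print(scale)  # a single scale vector shared by both consumers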

def insert_smooth_mul_op_per_input(scales, shape_infos, input_tensors_2_weights_nodes):
    """Insert a Mul layer after the input.

    The ops sharing the same input will share one Mul layer.

    Args:
        scales: the smooth scales, keyed by input tensor name
        shape_infos: the input tensor shape information
        input_tensors_2_weights_nodes: a dict mapping an input tensor name to its consumer weight nodes

    Returns:
        new_added_mul_nodes: added Mul layers
        new_init_tensors: added scale tensors
    """
    new_added_mul_nodes = []
    new_init_tensors = []  # scale tensors
    for key in scales.keys():
        scale_factor = 1.0 / scales[key]
        shape_info = shape_infos[key]
        if len(shape_info) == 3 or len(shape_info) == 2:  # the last dim is input channel
            pass
        elif len(shape_info) == 4:
            scale_factor = np.reshape(scale_factor, (1, -1, 1, 1))
        else:
            assert False, "not supported"
        name = key + "_" + "smooth_scale"
        scale_tensor = helper.make_tensor(
            name=name,
            data_type=onnx_proto.TensorProto.FLOAT,
            dims=scale_factor.shape,
            vals=scale_factor.flatten().tolist())
        new_init_tensors.append(scale_tensor)
        mul_output_name = key + "_smooth_output"
        mul_node = helper.make_node(
            "Mul",
            inputs=[key, key + "_" + "smooth_scale"],
            outputs=[mul_output_name],
            name=key + "_smooth_mul"
        )
        new_added_mul_nodes.append(mul_node)
        for node in input_tensors_2_weights_nodes[key]:
            for index, input in enumerate(node.input):
                if input == key:
                    node.input[index] = mul_output_name
    return new_added_mul_nodes, new_init_tensors
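The function rewires the consumers' inputs in place, so the caller only needs to splice the returned pieces into the graph. A minimal sketch of that wiring, assuming `model` is an onnx.ModelProto and the argument dicts were built during calibration; the graph may also need a topological re-sort afterwards, since the Mul nodes are appended at the end of the node list:

    new_mul_nodes, new_inits = insert_smooth_mul_op_per_input(
        scales, shape_infos, input_tensors_2_weights_nodes)
    model.graph.node.extend(new_mul_nodes)      # Mul(X, 1/s) now feeds each consumer
    model.graph.initializer.extend(new_inits)   # the 1/s constant tensors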

def adjust_weights_per_op(model, nodes, scales):
    """Adjust the weights with the per-op scales.

    Each op has one individual Mul layer.

    Args:
        model: the onnx model
        nodes: the nodes whose weights need to be adjusted, keyed by node name
        scales: the per-op smooth scales
    """
    name_to_indices = {}
    for index, i in enumerate(model.model.graph.initializer):
        name_to_indices[i.name] = index
    for key in nodes.keys():
        node = nodes[key]
        input = node.input[1]
        if input in name_to_indices.keys():
            weight = numpy_helper.to_array(model.model.graph.initializer[name_to_indices[input]])
            if len(weight.shape) == 2:
                scale = np.expand_dims(scales[key], axis=-1)  # TODO, to support conv
                new_weight = weight * scale
            elif len(weight.shape) == 4:  # TODO need to check conv
                scale = np.reshape(scales[key], (1, -1, 1, 1))
                new_weight = weight * scale
            else:
                assert False, "not supported"
            new_tensor = numpy_helper.from_array(new_weight, input)
            model.model.graph.initializer[name_to_indices[input]].CopyFrom(new_tensor)

def adjust_weights_per_input(model, nodes, scales):
    """Adjust the weights with the per-input scales.

    The ops sharing the same input will share one Mul layer.

    Args:
        model: the onnx model
        nodes: the nodes whose weights need to be adjusted, keyed by input tensor name
        scales: the per-input smooth scales
    """
    name_to_indices = {}
    for index, i in enumerate(model.model.graph.initializer):
        name_to_indices[i.name] = index
    for key in nodes.keys():
        curr_nodes = nodes[key]
        for node in curr_nodes:
            input = node.input[1]  # TODO
            if input in name_to_indices.keys():
                weight = numpy_helper.to_array(model.model.graph.initializer[name_to_indices[input]])
                if len(weight.shape) == 2:
                    scale = np.expand_dims(scales[key], axis=-1)  # TODO, to support conv
                    new_weight = weight * scale
                elif len(weight.shape) == 4:  # TODO need to check conv
                    scale = np.reshape(scales[key], (1, -1, 1, 1))
                    new_weight = weight * scale
                else:
                    assert False, "not supported"
                new_tensor = numpy_helper.from_array(new_weight, input)
                model.model.graph.initializer[name_to_indices[input]].CopyFrom(new_tensor)

def insert_smooth_mul_op_per_op(scales, shape_infos, input_tensors_2_weights_nodes):
    """Insert a Mul layer before each op.

    Each op has one individual Mul layer.

    Args:
        scales: the smooth scales, keyed by node name
        shape_infos: the input tensor shape information
        input_tensors_2_weights_nodes: a dict mapping an input tensor name to its consumer weight nodes

    Returns:
        new_added_mul_nodes: added Mul layers
        new_init_tensors: added scale tensors
        name_2_nodes: a dict mapping node name to node
    """
    name_2_nodes = {}
    for key in input_tensors_2_weights_nodes.keys():
        nodes = input_tensors_2_weights_nodes[key]
        for node in nodes:
            name_2_nodes[node.name] = node
    new_added_mul_nodes = []
    new_init_tensors = []  # scale tensors
    for input_key in input_tensors_2_weights_nodes.keys():
        shape_info = shape_infos[input_key]
        nodes = input_tensors_2_weights_nodes[input_key]
        for node in nodes:
            key = node.name
            scale_factor = 1.0 / scales[key]
            if len(shape_info) == 3 or len(shape_info) == 2:  # the last dim is input channel
                pass
            elif len(shape_info) == 4:
                scale_factor = np.reshape(scale_factor, (1, -1, 1, 1))
            else:
                assert False, "not supported"
            name = key + "_" + "smooth_scale"
            scale_tensor = helper.make_tensor(
                name=name,
                data_type=onnx_proto.TensorProto.FLOAT,
                dims=scale_factor.shape,
                vals=scale_factor.flatten().tolist())
            new_init_tensors.append(scale_tensor)
            mul_output_name = key + "_smooth_output"
            mul_node = helper.make_node(
                "Mul",
                inputs=[input_key, name],
                outputs=[mul_output_name],
                name=key + "_smooth_mul"
            )
            new_added_mul_nodes.append(mul_node)
            node = name_2_nodes[key]
            for index, input in enumerate(node.input):
                if input == input_key:
                    node.input[index] = mul_output_name
    return new_added_mul_nodes, new_init_tensors, name_2_nodes
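Putting the per-op helpers together, the intended flow appears to be: calibrate, derive per-op scales, insert the Muls, then fold the scales into the weights so that Mul(X, 1/s) and W*s cancel at full precision. An illustrative sketch under that assumption; `max_vals_per_channel`, `shape_infos`, and the `input_tensors_2_*` dicts are assumed outputs of activation calibration, and `model` is the ONNXModel wrapper used above:

    scales = get_smooth_scales_per_op(max_vals_per_channel, input_tensors_2_weights,
                                      input_tensors_2_weights_nodes, alpha=0.5)
    mul_nodes, init_tensors, name_2_nodes = insert_smooth_mul_op_per_op(
        scales, shape_infos, input_tensors_2_weights_nodes)
    adjust_weights_per_op(model, name_2_nodes, scales)   # W <- W * s cancels the X * (1/s) Mul
    model.model.graph.node.extend(mul_nodes)
    model.model.graph.initializer.extend(init_tensors)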
63 changes: 63 additions & 0 deletions neural_compressor/algorithm/smooth_quant.py
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Build SmoothQuant algorithm class."""

import numpy as np
from .algorithm import Algorithm, algorithm_registry
from ..utils import logger

@algorithm_registry(algorithm_type='smooth_quant')
class SmoothQuant(Algorithm):
    """SmoothQuant algorithm class."""

    def __init__(self, percentile=99.999, op_types=['MatMul', 'Linear', 'Conv'],
                 scales_per_op=True):
        """Initialize SmoothQuant class.

        Args:
            percentile: percentile of calibration to remove outliers
            op_types: the op types whose input tensor will be dumped
            scales_per_op: True, each op will have an individual scale, mainly for accuracy
                           False, ops with the same input will share a scale, mainly for performance
        """
        self.percentile = percentile
        self.op_types = op_types
        self.scales_per_op = scales_per_op
        self.alpha = 1.0
        self.tune_cfg = None

    def __call__(self, origin_model, q_model, adaptor, dataloader, iterations):
        """Return the processed model via the SmoothQuant algorithm.

        Fake input channel quantization; for more details please refer to:
        [1] SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models
        [2] SPIQ: Data-Free Per-Channel Static Input Quantization
        Insert a Mul op before each conv/matmul with adjusted weights.

        Args:
            origin_model: origin_model
            q_model: q_model
            adaptor: adaptor
            dataloader: dataloader
            iterations: iterations

        Returns:
            model: a modified onnx model
        """
        q_model = adaptor.smooth_quant(origin_model, dataloader, iterations, self.tune_cfg, self.alpha,
                                       self.percentile, self.op_types, self.scales_per_op)
        return q_model
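This diff does not show how the algorithm gets invoked; a plausible sketch, assuming `adaptor`, `dataloader`, and the two models come from an existing neural_compressor quantization session (class and argument names are taken from the code above):

    from neural_compressor.algorithm.smooth_quant import SmoothQuant

    sq = SmoothQuant(percentile=99.999, op_types=['MatMul', 'Conv'], scales_per_op=True)
    sq.alpha = 0.5   # smoothing strength; defaults to 1.0 here, 0.5 in the SmoothQuant paper
    smoothed_model = sq(origin_model, q_model, adaptor, dataloader, iterations=100)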
2 changes: 2 additions & 0 deletions neural_compressor/conf/config.py
@@ -1399,6 +1399,8 @@ def map_pyconfig_to_cfg(self, pythonic_config):
             'model.outputs': pythonic_config.quantization.outputs,
             'model.backend': pythonic_config.quantization.backend,
             'model.quant_format': pythonic_config.quantization.quant_format,
+            'model.domain': pythonic_config.quantization.domain,
+            'quantization.recipes': pythonic_config.quantization.recipes,
             'quantization.approach': pythonic_config.quantization.approach,
             'quantization.calibration.sampling_size':
                 pythonic_config.quantization.calibration_sampling_size,
3 changes: 1 addition & 2 deletions neural_compressor/conf/pythonic_config.py
@@ -63,7 +63,7 @@ def __init__(self,
             accuracy_criterion=accuracy_criterion,
             quant_level=quant_level
         )
-        self._approach = approach
+        self.approach = approach

     @property
     def approach(self):
@@ -77,7 +77,6 @@ def approach(self, approach):
         ):
             self._approach = approach

-
 class WeightConf:
     def __init__(self, datatype=None, scheme=None, granularity=None, algorithm=None):
         self._datatype = datatype
