Samples Update & Bug fix #207

Merged
merged 2 commits on Aug 16, 2022
12 changes: 7 additions & 5 deletions ppq/__init__.py
@@ -1,13 +1,15 @@
# This file defines export functions & class of PPQ.
from ppq.api.setting import (ActivationQuantizationSetting, DispatchingTable,
EqualizationSetting, LSQSetting,
GraphFormatSetting, ParameterQuantizationSetting,
EqualizationSetting, GraphFormatSetting,
LSQSetting, ParameterQuantizationSetting,
QuantizationFusionSetting, QuantizationSetting,
QuantizationSettingFactory, TemplateSetting)
from ppq.core import *
from ppq.executor import TorchExecutor, TorchQuantizeDelegator
from ppq.IR import (BaseGraph, GraphCommand, GraphFormatter, Operation,
QuantableGraph, SearchableGraph, Variable)
from ppq.executor import (BaseGraphExecutor, TorchExecutor,
TorchQuantizeDelegator)
from ppq.IR import (BaseGraph, GraphBuilder, GraphCommand, GraphExporter,
GraphFormatter, Operation, QuantableGraph, SearchableGraph,
Variable)
from ppq.IR.deploy import RunnableGraph
from ppq.IR.quantize import QuantableOperation, QuantableVariable
from ppq.IR.search import SearchableGraph
11 changes: 3 additions & 8 deletions ppq/api/interface.py
@@ -23,11 +23,8 @@
ORT_PerChannelQuantizer,
ORT_PerTensorQuantizer,
PPL_DSP_Quantizer,
PPL_DSP_TI_Quantizer,
PPLCUDA_INT4_Quantizer,
PPLCUDAMixPrecisionQuantizer,
PPLCUDAQuantizer, TensorRTQuantizer,
TengineQuantizer)
PPL_DSP_TI_Quantizer, PPLCUDAQuantizer,
TensorRTQuantizer, TengineQuantizer)
from ppq.scheduler import DISPATCHER_TABLE, GraphDispatcher
from ppq.scheduler.perseus import Perseus
from torch.utils.data import DataLoader
@@ -48,8 +45,6 @@
# TargetPlatform.ORT_OOS_INT8: ORT_PerChannelQuantizer,
TargetPlatform.PPL_CUDA_INT8: PPLCUDAQuantizer,
TargetPlatform.EXTENSION: ExtQuantizer,
TargetPlatform.PPL_CUDA_MIX: PPLCUDAMixPrecisionQuantizer,
TargetPlatform.PPL_CUDA_INT4: PPLCUDA_INT4_Quantizer,
TargetPlatform.ACADEMIC_INT8: ACADEMICQuantizer,
TargetPlatform.ACADEMIC_INT4: ACADEMIC_INT4_Quantizer,
TargetPlatform.ACADEMIC_MIX: ACADEMIC_Mix_Quantizer,
@@ -83,7 +78,7 @@
TargetPlatform.METAX_INT8_T: ONNXRUNTIMExporter,
TargetPlatform.TRT_INT8: TensorRTExporter,
TargetPlatform.NCNN_INT8: NCNNExporter,
TargetPlatform.TENGINE_INT8: TengineExporter,
TargetPlatform.TENGINE_INT8: TengineExporter
}

# Pick a nice file suffix for your exported model
23 changes: 17 additions & 6 deletions ppq/core/quant.py
@@ -52,7 +52,7 @@ class TargetPlatform(Enum):
NCNN_INT8 = 102
OPENVINO_INT8 = 103
TENGINE_INT8 = 104

PPL_CUDA_INT8 = 201
PPL_CUDA_INT4 = 202
PPL_CUDA_FP16 = 203
@@ -75,7 +75,7 @@ class TargetPlatform(Enum):

HEXAGON_INT8 = 801



FP32 = 0
# SHAPE-OR-INDEX related operation
@@ -366,7 +366,7 @@ def __init__(
offset: Any = None,
observer_algorithm: str = None,
detail: Any = None,
inplace: bool = False,
require_export: bool = None,
state: QuantizationStates = QuantizationStates.INITIAL
):
"""Create a PPQ Tensor Quantization Configuration Instance.
@@ -396,7 +396,7 @@ def __init__(
detail (Any, optional): Only used by PPQ internal logic, detail is used to store some internal data,
you are not supposed to use it.

inplace (bool, optional): Indicates whether quantization is taken inplace(rewrite tensor value).
require_export (bool, optional): If require_export == True, PPQ exporter will export this TQC ignoring state checks.

state (QuantizationStates, optional):
Defaults to QuantizationStates.INITIAL, see QuantizationStates for more detail.
@@ -414,10 +414,10 @@ def __init__(
self._quant_min = quant_min
self._quant_max = quant_max
self.observer_algorithm = observer_algorithm
self.inplace = inplace
self.detail = {} if detail is None else detail
self._father_config = self # union-find
self._hash = self.__create_hash()
self._require_export = require_export
super().__init__()

@ abstractmethod
@@ -509,6 +509,18 @@ def is_revisable(self):
QuantizationStates.FP32
})

@ property
def exportable(self) -> bool:
value_check = isinstance(self.scale, torch.Tensor)
if self._require_export is None:
state_check = QuantizationStates.can_export(self.state)
return (value_check and state_check)
else: return (self._require_export and value_check)

@ exportable.setter
def exportable(self, export_override: bool):
self._require_export = export_override

@ property
def scale(self) -> torch.Tensor:
if self.dominated_by == self: return self._scale
@@ -650,7 +662,6 @@ def copy(self):
scale=scale, offset=offset,
observer_algorithm=self.observer_algorithm,
detail=self.detail.copy(),
inplace=self.inplace,
state=self.state
)
if self.state == QuantizationStates.OVERLAPPED:
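Alongside dropping the old `inplace` flag, this hunk adds a `require_export` override exposed through the new `exportable` property: by default a TQC is exportable when its state passes `QuantizationStates.can_export` and its scale is a real `torch.Tensor`; assigning to the property bypasses the state check. A minimal usage sketch, assuming an already-quantized `BaseGraph` named `quantized_graph` (the loop itself is illustrative, not part of this PR):

# Illustrative sketch only (not part of this diff): force a TQC to be exported.
from ppq import QuantableOperation

for op in quantized_graph.operations.values():        # `quantized_graph`: an assumed, already-quantized BaseGraph
    if not isinstance(op, QuantableOperation):
        continue
    for tqc in op.config.input_quantization_config:
        # Default behaviour: exportable == (scale is a torch.Tensor) and QuantizationStates.can_export(state).
        if not tqc.exportable:
            # Override the state check; a tensor-valued scale is still required.
            tqc.exportable = True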
105 changes: 60 additions & 45 deletions ppq/parser/onnx_exporter.py
@@ -68,8 +68,8 @@ class OnnxExporter(GraphExporter):
def __init__(self) -> None:
super().__init__()

def export_quantization_config(
self, config_path: str, graph: BaseGraph):
def export_quantization_config(self, config_path: str, graph: BaseGraph):
"""Export Tensor Quantization Config to File(Json)."""

render_buffer = {
'configs': {},
@@ -106,7 +106,59 @@ def export_quantization_config(
with open(file=config_path, mode='w') as file:
json.dump(render_buffer, file, indent=4)

def export_graph(self, graph: BaseGraph) -> onnx.GraphProto:
"""
Convert a PPQ IR to Onnx IR.
This export only converts PPQ Ops and vars to onnx; all quantization configs will be skipped.

This function will try to keep the opset version of your graph unchanged.
However, if the opset is not given, ppq will set it to the global parameter ppq.core.ONNX_EXPORT_OPSET.
"""
name = graph._name
if not name: name = f'{PPQ_CONFIG.NAME} - v({PPQ_CONFIG.VERSION})'

# Ready to export onnx graph definition.
_inputs, _outputs, _initilizers, _nodes = [], [], [], []
for operation in graph.topological_sort():
_nodes.append(self.export_operation(operation))

for variable in graph.variables.values():
tensor_proto = self.export_var(variable)
if variable.name in graph.inputs:
_inputs.append(tensor_proto)
if variable.name in graph.outputs:
_outputs.append(tensor_proto)
if variable.is_parameter:
_initilizers.append(tensor_proto)

graph_def = helper.make_graph(
name=name, nodes=_nodes,
inputs=_inputs, outputs=_outputs,
initializer=_initilizers)

# if opset is missing from your graph, give it a default one.
if GRAPH_OPSET_ATTRIB not in graph._detail:
op = onnx.OperatorSetIdProto()
op.version = ONNX_EXPORT_OPSET
opsets = [op]
else:
opsets = []
for opset in graph._detail[GRAPH_OPSET_ATTRIB]:
op = onnx.OperatorSetIdProto()
op.domain = opset['domain']
op.version = opset['version']
opsets.append(op)

onnx_model = helper.make_model(
graph_def, producer_name=PPQ_CONFIG.NAME, opset_imports=opsets)
onnx_model.ir_version = graph._detail.get('ir_version', ONNX_VERSION)
return onnx_model

def export_operation(self, operation: Operation) -> onnx.OperatorProto:
"""
Convert a PPQ Op to an Onnx Operation.
An Op consumes zero or more Tensors and produces zero or more Tensors.
"""
if operation.type in OPERATION_EXPORTERS:
exporter = OPERATION_EXPORTERS[operation.type]()
assert isinstance(exporter, OperationExporter), (
@@ -132,6 +184,11 @@ def export_operation(self, operation: Operation) -> onnx.OperatorProto:
return op_proto

def export_var(self, variable: Variable) -> onnx.TensorProto:
"""
Convert a PPQ Variable to an Onnx Tensor. There are two different types of Tensor in Onnx:
Variable: Represents a Tensor whose value is not known until inference-time.
Constant: Represents a Tensor whose value is known.
"""
if variable.meta is not None:
shape = variable.meta.shape
dtype = variable.meta.dtype.value
@@ -174,50 +231,8 @@ def export(self, file_path: str, graph: BaseGraph, config_path: str = None):
processor = GraphDeviceSwitcher(graph)
processor.remove_switcher()

name = graph._name
if not name: name = f'{PPQ_CONFIG.NAME} - v({PPQ_CONFIG.VERSION})'

# if a valid config path is given, export quantization config to there.
if config_path is not None:
self.export_quantization_config(config_path, graph)

# Ready to export onnx graph defination.
_inputs, _outputs, _initilizers, _nodes = [], [], [], []
for operation in graph.topological_sort():
_nodes.append(self.export_operation(operation))

for variable in graph.variables.values():
tensor_proto = self.export_var(variable)
if variable.name in graph.inputs:
_inputs.append(tensor_proto)
if variable.name in graph.outputs:
_outputs.append(tensor_proto)
if variable.is_parameter:
_initilizers.append(tensor_proto)

graph_def = helper.make_graph(
name=name,
nodes=_nodes,
inputs=_inputs,
outputs=_outputs,
initializer=_initilizers,
)

# force opset to 11
if GRAPH_OPSET_ATTRIB not in graph._detail:
op = onnx.OperatorSetIdProto()
op.version = ONNX_EXPORT_OPSET
opsets = [op]
else:
opsets = []
for opset in graph._detail[GRAPH_OPSET_ATTRIB]:
op = onnx.OperatorSetIdProto()
op.domain = opset['domain']
op.version = opset['version']
opsets.append(op)

onnx_model = helper.make_model(
graph_def, producer_name=PPQ_CONFIG.NAME, opset_imports=opsets)
onnx_model.ir_version = graph._detail.get('ir_version', ONNX_VERSION)
# onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, file_path)
onnx.save(self.export_graph(graph=graph), file_path)
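
With this refactor, `export()` delegates graph assembly to the new `export_graph()` method and only handles file writing. A hedged sketch of calling the helper directly to get an in-memory ONNX model (import path taken from this file's location; `graph` is an assumed, already-processed `BaseGraph`):

# Illustrative sketch only: use export_graph() without writing a file.
import onnx
from ppq.parser.onnx_exporter import OnnxExporter

exporter   = OnnxExporter()
onnx_model = exporter.export_graph(graph=graph)   # assembles nodes, tensors and opset imports in memory
onnx.checker.check_model(onnx_model)              # optional sanity check (the exporter itself leaves this commented out)
onnx.save(onnx_model, 'quantized.onnx')           # equivalent to the final step of export()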
11 changes: 9 additions & 2 deletions ppq/quantization/algorithm/training.py
@@ -275,9 +275,16 @@ def _find_multi_input_ep(op: Operation):

def _find_coherent_ep(op: Operation):
# If the current op has exactly one successor, keep walking down that chain,
# but if the direct successor takes more than one input, stop immediately.
ops = self.graph.get_downstream_operations(op)
if len(ops) == 1 and len(self.graph.get_upstream_operations(ops[0])) == 1:
return ops[0]
if len(ops) == 1:
following_op = ops[0]
# PATCH 20220811: get_upstream_operations is not enough to tell whether an op has only one input,
# because an op can also be connected directly to a graph input...
non_parameter_input = following_op.num_of_input - following_op.num_of_parameter
upstream_ops = len(self.graph.get_upstream_operations(following_op))
if non_parameter_input == 1 and upstream_ops == 1:
return ops[0]
return None

sp, ep, future_ep = op, op, op
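The patched condition stops block growth when the direct successor takes more than one non-parameter input, even if it has only one upstream operation, because the extra input may come straight from a graph input. A small illustrative sketch of the check (`graph` and `following_op` are hypothetical here):

# Illustrative only: why counting upstream operations is not enough.
#
#   Conv ──► Add ◄── graph input (no producing op)
#
# Add has one upstream operation (Conv) but two non-parameter inputs,
# so the training block must stop at Conv instead of absorbing Add.
non_parameter_inputs = following_op.num_of_input - following_op.num_of_parameter
single_upstream      = len(graph.get_upstream_operations(following_op)) == 1
can_extend_block     = (non_parameter_inputs == 1) and single_upstream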
21 changes: 14 additions & 7 deletions ppq/quantization/optim/parameters.py
@@ -44,13 +44,6 @@ def check_state(state: QuantizationStates):
if op.num_of_input == 3:
i_cfg, w_cfg, b_cfg = op.config.input_quantization_config

if not check_state(w_cfg.state):
raise PermissionError(f'Can not quantize bias of layer {op.name}, '
'cause weight has not been correctly quantized.')
if not check_state(i_cfg.state):
raise PermissionError(f'Can not quantize bias of layer {op.name}, '
'cause input has not been correctly quantized.')

# PATCH 2022.07.29: sometimes bias is multi-dimensional; in that case all leading dimensions must be 1
bias = op.inputs[-1].value
assert bias.numel() == bias.shape[-1], (
@@ -63,13 +56,27 @@

# The logic below runs in two cases: 1. the state is PASSIVE_INIT, 2. override is requested
if self._override and (b_cfg.state == QuantizationStates.PASSIVE):
if not check_state(w_cfg.state):
raise PermissionError(f'Can not quantize bias of layer {op.name}, '
'cause weight has not been correctly quantized.')
if not check_state(i_cfg.state):
raise PermissionError(f'Can not quantize bias of layer {op.name}, '
'cause input has not been correctly quantized.')

b_cfg.scale = w_cfg.scale * i_cfg.scale * self.scale_multiplier
b_cfg.state = QuantizationStates.PASSIVE
b_cfg.offset = torch.zeros_like(b_cfg.scale)
assert not b_cfg.policy.has_property(QuantizationProperty.ASYMMETRICAL), (
'Passive parameter does not support ASYMMETRICAL quantization')

if b_cfg.state == QuantizationStates.PASSIVE_INIT:
if not check_state(w_cfg.state):
raise PermissionError(f'Can not quantize bias of layer {op.name}, '
'cause weight has not been correctly quantized.')
if not check_state(i_cfg.state):
raise PermissionError(f'Can not quantize bias of layer {op.name}, '
'cause input has not been correctly quantized.')

b_cfg.scale = w_cfg.scale * i_cfg.scale * self.scale_multiplier
b_cfg.state = QuantizationStates.PASSIVE
b_cfg.offset = torch.zeros_like(b_cfg.scale)
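In both branches the passive bias scale is derived from the already-calibrated input and weight scales. A hedged numeric sketch of that composition (made-up values; `scale_multiplier = 1.0` is assumed purely for illustration):

# Illustrative sketch only: how a passive bias scale is composed.
import torch

input_scale      = torch.tensor(0.05)           # i_cfg.scale (per-tensor activation scale)
weight_scale     = torch.tensor([0.01, 0.02])   # w_cfg.scale (per-channel weight scales)
scale_multiplier = 1.0                          # assumed value of self.scale_multiplier

# bias_int * bias_scale should match (input_int * input_scale) * (weight_int * weight_scale)
bias_scale  = weight_scale * input_scale * scale_multiplier   # -> 0.0005 and 0.0010
bias_offset = torch.zeros_like(bias_scale)                    # passive parameters stay symmetric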
40 changes: 1 addition & 39 deletions ppq/quantization/quantizer/PPLQuantizer.py
@@ -5,8 +5,7 @@
OperationQuantizationConfig, QuantizationPolicy,
QuantizationProperty, QuantizationStates, RoundingPolicy,
TargetPlatform)
from ppq.core.quant import TensorQuantizationConfig
from ppq.IR import BaseGraph, Operation, Variable
from ppq.IR import BaseGraph, Operation

from .base import BaseQuantizer

@@ -119,40 +118,3 @@ def rounding_policy(self) -> RoundingPolicy:
@ property
def activation_fusion_types(self) -> set:
return {'Relu', 'Clip', 'Sigmoid', 'LeakyRelu'}


class PPLCUDAMixPrecisionQuantizer(PPLCUDAQuantizer):
def __init__(
self, graph: Union[BaseGraph, BaseGraph]
) -> Union[torch.Tensor, list, dict]:
super().__init__(graph=graph)

def init_quantize_config(self, operation: Operation) -> OperationQuantizationConfig:
config = super().init_quantize_config(operation=operation)
if operation.platform == TargetPlatform.PPL_CUDA_INT4:
for cfg, var in zip(config.input_quantization_config, operation.inputs):
assert isinstance(cfg, TensorQuantizationConfig)
assert isinstance(var, Variable)
if cfg.state == QuantizationStates.INITIAL:
cfg.num_of_bits, cfg.quant_max, cfg.quant_min = 4, 7, -8
return config


class PPLCUDA_INT4_Quantizer(PPLCUDAQuantizer):
def __init__(
self, graph: Union[BaseGraph, BaseGraph]
) -> Union[torch.Tensor, list, dict]:
super().__init__(graph=graph)

def __init__(
self, graph: Union[BaseGraph, BaseGraph]
) -> Union[torch.Tensor, list, dict]:

super().__init__(graph=graph)
self._num_of_bits = 4
self._quant_min = - int(pow(2, self._num_of_bits - 1))
self._quant_max = int(pow(2, self._num_of_bits - 1) - 1)

@ property
def target_platform(self) -> TargetPlatform:
return TargetPlatform.PPL_CUDA_INT8
3 changes: 1 addition & 2 deletions ppq/quantization/quantizer/__init__.py
@@ -6,8 +6,7 @@
from .MyQuantizer import ExtQuantizer
from .NXPQuantizer import NXP_Quantizer
from .ORTQuantizer import ORT_PerChannelQuantizer, ORT_PerTensorQuantizer
from .PPLQuantizer import (PPLCUDA_INT4_Quantizer,
PPLCUDAMixPrecisionQuantizer, PPLCUDAQuantizer)
from .PPLQuantizer import PPLCUDAQuantizer
from .TRTQuantizer import TensorRTQuantizer
from .FPGAQuantizer import FPGAQuantizer
from .NCNNQuantizer import NCNNQuantizer
1 change: 0 additions & 1 deletion ppq/quantization/quantizer/base.py
@@ -197,7 +197,6 @@ def target_platform(self) -> TargetPlatform:
def default_platform(self) -> TargetPlatform:
return TargetPlatform.FP32

@ abstractproperty
@ property
def quantize_policy(self) -> QuantizationPolicy:
raise NotImplementedError('Quantizier does not have a default quantization policy yet.')