From a583906c9a4ac2c799b7bf4acabec60750bb84e8 Mon Sep 17 00:00:00 2001 From: AwesomeCodingBoy <43309460+ZhangZhiPku@users.noreply.github.com> Date: Mon, 13 Mar 2023 21:16:20 +0800 Subject: [PATCH] Fix Executing Bugs (#408) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix Executing Bugs * 修复了 unsqueeze 算子在多于一个轴时的顺序错误问题 * 修复了 softmax 算子在 opset 11 时默认轴错误的问题 * 修复了 图拷贝 过程中可能因为空值而出现的错误 --- ppq/IR/quantize.py | 7 ++++--- ppq/executor/op/torch/default.py | 6 ++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ppq/IR/quantize.py b/ppq/IR/quantize.py index fd62fa60..d5398fda 100644 --- a/ppq/IR/quantize.py +++ b/ppq/IR/quantize.py @@ -237,7 +237,8 @@ def source_op_platform(self) -> TargetPlatform: def copy(self, copy_value: bool = False): clone = QuantableVariable(super().copy(copy_value)) - if copy_value: clone._fp32_value = self._fp32_value.clone() + if copy_value and self._fp32_value is not None: + clone._fp32_value = self._fp32_value.clone() else: clone._fp32_value = self._fp32_value return clone @@ -314,10 +315,10 @@ def dequantize_graph(self, expire_device: str = 'cpu'): """一个方便懒人的函数.""" for operation in self.graph.operations.values(): if isinstance(operation, QuantableOperation): - operation.dequantize() + operation.dequantize(expire_device=expire_device) def restore_quantize_state(self, expire_device: str = 'cpu'): """一个方便懒人的函数.""" for operation in self.graph.operations.values(): if isinstance(operation, QuantableOperation): - operation.restore_quantize_state() + operation.restore_quantize_state(expire_device=expire_device) diff --git a/ppq/executor/op/torch/default.py b/ppq/executor/op/torch/default.py index d943c698..35017bc5 100644 --- a/ppq/executor/op/torch/default.py +++ b/ppq/executor/op/torch/default.py @@ -1088,7 +1088,7 @@ def Unsqueeze_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacke axes = GET_ATTRIBUTE_FROM_OPERATION(op=op, attribute='axes', compulsive=True) if isinstance(axes, list): - for squeezing_dim in sorted(axes, reverse=True): + for squeezing_dim in sorted(axes): unsqueezing_tensor = torch.unsqueeze(unsqueezing_tensor, squeezing_dim) elif isinstance(axes, int): unsqueezing_tensor = torch.unsqueeze(unsqueezing_tensor, axes) @@ -2113,9 +2113,11 @@ def Softmax_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackend Returns: torch.Tensor: [description] """ + if op.opset.onnx_opset_version() >= 13: default_axis = -1 + else: default_axis = 1 ASSERT_NUM_OF_INPUT(op=op, values=values, min_num_of_input=1, max_num_of_input=1) [input] = values - axis = GET_ATTRIBUTE_FROM_OPERATION(op=op, attribute='axis', default=-1) + axis = GET_ATTRIBUTE_FROM_OPERATION(op=op, attribute='axis', default=default_axis) output = F.softmax(input, axis) return output