Skip to content

Commit

Permalink
Fix Executing Bugs (#408)
Browse files Browse the repository at this point in the history
* Fix Executing Bugs

* 修复了 unsqueeze 算子在多于一个轴时的顺序错误问题
* 修复了 softmax 算子在 opset 11 时默认轴错误的问题
* 修复了 图拷贝 过程中可能因为空值而出现的错误
  • Loading branch information
ZhangZhiPku committed Mar 13, 2023
1 parent 845086f commit a583906
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
7 changes: 4 additions & 3 deletions ppq/IR/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def source_op_platform(self) -> TargetPlatform:

def copy(self, copy_value: bool = False):
clone = QuantableVariable(super().copy(copy_value))
if copy_value: clone._fp32_value = self._fp32_value.clone()
if copy_value and self._fp32_value is not None:
clone._fp32_value = self._fp32_value.clone()
else: clone._fp32_value = self._fp32_value
return clone

Expand Down Expand Up @@ -314,10 +315,10 @@ def dequantize_graph(self, expire_device: str = 'cpu'):
"""一个方便懒人的函数."""
for operation in self.graph.operations.values():
if isinstance(operation, QuantableOperation):
operation.dequantize()
operation.dequantize(expire_device=expire_device)

def restore_quantize_state(self, expire_device: str = 'cpu'):
"""一个方便懒人的函数."""
for operation in self.graph.operations.values():
if isinstance(operation, QuantableOperation):
operation.restore_quantize_state()
operation.restore_quantize_state(expire_device=expire_device)
6 changes: 4 additions & 2 deletions ppq/executor/op/torch/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ def Unsqueeze_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacke
axes = GET_ATTRIBUTE_FROM_OPERATION(op=op, attribute='axes', compulsive=True)

if isinstance(axes, list):
for squeezing_dim in sorted(axes, reverse=True):
for squeezing_dim in sorted(axes):
unsqueezing_tensor = torch.unsqueeze(unsqueezing_tensor, squeezing_dim)
elif isinstance(axes, int):
unsqueezing_tensor = torch.unsqueeze(unsqueezing_tensor, axes)
Expand Down Expand Up @@ -2113,9 +2113,11 @@ def Softmax_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackend
Returns:
torch.Tensor: [description]
"""
if op.opset.onnx_opset_version() >= 13: default_axis = -1
else: default_axis = 1
ASSERT_NUM_OF_INPUT(op=op, values=values, min_num_of_input=1, max_num_of_input=1)
[input] = values
axis = GET_ATTRIBUTE_FROM_OPERATION(op=op, attribute='axis', default=-1)
axis = GET_ATTRIBUTE_FROM_OPERATION(op=op, attribute='axis', default=default_axis)
output = F.softmax(input, axis)
return output

Expand Down

0 comments on commit a583906

Please sign in to comment.