
With ONNX model inputs, operators that involve dimension changes, such as Reshape / Flatten, easily break the batched calibration pass when their shape parameters are fixed #161

Closed
wusongchao opened this issue Jun 13, 2022 · 1 comment

Comments

@wusongchao

This is probably a fairly general problem: some ONNX models handed over for deployment, perhaps after shape inference, onnx-simplifier, or some other processing, have Reshape-style operators whose shape input is a fixed constant. As a result, during the calibration pass, once the input is a batch, the torch executor falls over.

Take this model as an example: the shape input of the Reshape_71 operator is fixed.
[screenshot of the model graph showing Reshape_71 with a constant shape input [1, 2, 48, 34, 60]]

Unsurprisingly, with batch_size = 32 the input carries 1*2*48*34*60*32 = 6,266,880 elements, which cannot be reshaped to [1, 2, 48, 34, 60] (195,840 elements).
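
For reference, the size mismatch reproduces directly in PyTorch (a minimal sketch; the input tensor shape here is inferred from the target shape and the error message below):

import torch

# Reshape_71's constant target shape holds 1*2*48*34*60 = 195,840 elements.
fixed_shape = [1, 2, 48, 34, 60]

# A calibration batch of 32 carries 32x as many elements: 6,266,880.
batch = torch.randn(32, 2, 48, 34, 60)

# Raises: RuntimeError: shape '[1, 2, 48, 34, 60]' is invalid for input of size 6266880
batch.reshape(fixed_shape)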

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/ppq/executor/torch.py", line 359, in __forward
    outputs = operation_forward_func(operation, inputs, self._executing_contenxt)
  File "/usr/local/lib/python3.6/dist-packages/ppq/executor/op/torch/default.py", line 458, in Reshape_forward
    return data.reshape(shape)
RuntimeError: shape '[1, 2, 48, 34, 60]' is invalid for input of size 6266880

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "ppq-entrance.py", line 66, in <module>
    device=DEVICE, verbose=0)
  File "/usr/local/lib/python3.6/dist-packages/ppq/core/defs.py", line 65, in _wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/ppq/api/interface.py", line 267, in quantize_onnx_model
    collate_fn=collate_fn
  File "/usr/local/lib/python3.6/dist-packages/ppq/core/defs.py", line 65, in _wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/ppq/quantization/quantizer/base.py", line 74, in quantize
    **kwargs
  File "/usr/local/lib/python3.6/dist-packages/ppq/quantization/optim/base.py", line 95, in optimize
    optimization_pass.apply(processer=processer, dataloader=dataloader, executor=executor, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/ppq/quantization/optim/base.py", line 30, in apply
    self.optimize(processer, dataloader=dataloader, executor=executor, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/ppq/core/defs.py", line 65, in _wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/ppq/quantization/optim/calibration.py", line 117, in optimize
    executor=executor, hooks=hooks, output_names=None)
  File "/usr/local/lib/python3.6/dist-packages/ppq/quantization/optim/calibration.py", line 59, in calibrate
    output_names=output_names)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/ppq/executor/torch.py", line 231, in forward
    hooks=hooks
  File "/usr/local/lib/python3.6/dist-packages/ppq/executor/torch.py", line 387, in __forward
    raise RuntimeError(f'Error happens when dealing with operation {str(operation)}')
RuntimeError: Error happens when dealing with operation Reshape_71(TargetPlatform.FP32) - inputs:['708', '1894'], outputs:['720']
@ZhangZhiPku
Collaborator

Unfortunately, this problem does not seem solvable: if the Reshape in your model hard-codes shape=[1,2,48,34,60], then semantically it simply cannot run a batch_size=16 case. One mitigation is to change the implementation of the Reshape forward, e.g. force the first dimension to -1 so it runs through...
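
A minimal sketch of that workaround in plain PyTorch (the helper name is illustrative; in PPQ the change would go at the end of Reshape_forward in ppq/executor/op/torch/default.py, where the traceback above shows data.reshape(shape)):

import torch

def reshape_with_dynamic_batch(data, shape):
    # Illustrative stand-in for the final data.reshape(shape) call:
    # force the first (batch) dimension to -1 so torch infers it from
    # the actual input size. Only safe when no other dim is already -1.
    shape = [int(s) for s in shape]
    if shape and -1 not in shape[1:]:
        shape[0] = -1
    return data.reshape(shape)

# The failing case from the traceback now runs:
batch = torch.randn(32, 2, 48, 34, 60)
print(reshape_with_dynamic_batch(batch, [1, 2, 48, 34, 60]).shape)
# torch.Size([32, 2, 48, 34, 60])

Note that this only unblocks the executor during calibration; the exported graph still carries the hard-coded shape initializer, so the semantic limitation described above remains.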
