ONNX is great, but sometimes too complicated.
One day I wanted to export the following simple reshape operation to ONNX:
```python
import torch


class JustReshape(torch.nn.Module):
    def __init__(self):
        super(JustReshape, self).__init__()

    def forward(self, x):
        # Reinterpret an (N, C, H, W) tensor as (N, C, W, H); view() moves no data.
        return x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2]))


net = JustReshape()
model_name = 'just_reshape.onnx'
dummy_input = torch.randn(2, 3, 4, 5)
torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'])
```
Input shapes in ONNX are static, so I expected the exported model to contain just a single Reshape node with a constant target shape.
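However, even after polishing, I got a much more complicated model, in which the target shape is computed at runtime from the input's shape by a chain of extra operators.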
Moreover, some models also contain operations performed on weights, which can all be eliminated by computing them offline.
ONNX Simplifier is designed to simplify ONNX models. It infers the whole computation graph and then replaces redundant operators with their constant outputs (a technique known as constant folding).
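To make the idea concrete, here is a minimal sketch of constant folding on a toy graph. It is plain Python, not the ONNX protobuf format and not onnxsim's actual implementation; `Node`, `evaluate`, `fold_constants` and all tensor names are hypothetical. The graph roughly mirrors the JustReshape export: the target shape is derived from the input's shape through Shape, Gather and Concat nodes (the real export may contain further operators). Because the input shape is static, every node except Reshape can be evaluated ahead of time.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Node:
    op: str            # operator type, e.g. "Shape" or "Gather"
    inputs: List[str]  # names of the input tensors
    output: str        # name of the single output tensor


def evaluate(node: Node, consts: Dict[str, object],
             static_shapes: Dict[str, Tuple[int, ...]]) -> Optional[object]:
    """Return the node's output if it is computable offline, else None."""
    if node.op == "Shape":
        # Shape only needs the input's shape, which is static in this toy model.
        return static_shapes.get(node.inputs[0])
    if not all(name in consts for name in node.inputs):
        return None  # depends on runtime data
    args = [consts[name] for name in node.inputs]
    if node.op == "Gather":
        # Simplified: pick one element of a 1-D value by a scalar index.
        return args[0][args[1]]
    if node.op == "Concat":
        # Simplified: join scalars into one shape tuple.
        return tuple(args)
    return None  # ops this toy evaluator does not know are kept


def fold_constants(nodes: List[Node],
                   static_shapes: Dict[str, Tuple[int, ...]],
                   initializers: Dict[str, object]):
    """Drop every node whose output is computable offline.

    Nodes must be topologically sorted. Returns the remaining nodes and a
    dict of all known constants (initializers plus folded outputs).
    """
    consts: Dict[str, object] = dict(initializers)
    remaining: List[Node] = []
    for node in nodes:
        value = evaluate(node, consts, static_shapes)
        if value is None:
            remaining.append(node)     # needs runtime data: keep the node
        else:
            consts[node.output] = value  # folded: replaced by a constant
    return remaining, consts


# Hypothetical toy graph roughly mirroring the JustReshape export.
graph = [
    Node("Shape", ["input"], "s"),
    Node("Gather", ["s", "i0"], "d0"),
    Node("Gather", ["s", "i1"], "d1"),
    Node("Gather", ["s", "i3"], "d3"),
    Node("Gather", ["s", "i2"], "d2"),
    Node("Concat", ["d0", "d1", "d3", "d2"], "target"),
    Node("Reshape", ["input", "target"], "output"),
]
remaining, consts = fold_constants(
    graph,
    static_shapes={"input": (2, 3, 4, 5)},
    initializers={"i0": 0, "i1": 1, "i2": 2, "i3": 3},
)
print([n.op for n in remaining])  # ['Reshape']
print(consts["target"])           # (2, 3, 5, 4)
```

Operations on weights fold the same way: any node whose inputs are all initializers can be evaluated once and stored as a new constant.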
Install it via pip (Python >= 3.5):

```bash
pip3 install onnx-simplifier
```

Then simplify a model with:

```bash
python3 -m onnxsim input_model output_model
```
[Figure: an overall comparison between a complicated model and its simplified version]