
Optimizer is broken for new PyTorch exports and segfault in onnx.checker #2417

Closed
dashesy opened this issue Oct 24, 2019 · 16 comments
Labels: optimizer (Issues related to ONNX optimizers)


dashesy commented Oct 24, 2019

First export a model (as I did for this issue) with the latest PyTorch:

import torch
import torch.nn as nn

class ToFloat(nn.Module):
    def __init__(self):
        """Convert to .float()
        """
        super(ToFloat, self).__init__()

    def forward(self, x):
        return x.float()

@torch.jit.script
def center_slice_helper(x, h_offset, w_offset, h_end, w_end):
    return x[:, :, h_offset:h_end, w_offset:w_end]


class CenterCrop(nn.Module):
    def __init__(self, crop_size, use_jit=False):
        """Crop from the center of a 4d tensor
        Input shape can be dynamic
        :param crop_size: the center crop size
        :param use_jit: whether to use the jit helper
        """
        super(CenterCrop, self).__init__()
        self.crop_size = crop_size
        self.use_jit = use_jit
        self.register_buffer('crop_size_t', torch.tensor(crop_size))

    def extra_repr(self):
        """Extra information
        """
        return 'crop_size={}'.format(
            self.crop_size
        )

    def forward(self, x):
        """
        :type x: torch.Tensor
        """
        height, width = x.shape[2], x.shape[3]
        if not isinstance(height, torch.Tensor):
            height, width = torch.tensor(height).to(x.device), torch.tensor(width).to(x.device)
        h_offset = (height - self.crop_size_t) / 2
        w_offset = (width - self.crop_size_t) / 2
        h_end = h_offset + self.crop_size_t
        w_end = w_offset + self.crop_size_t
        if self.use_jit:
            return center_slice_helper(x, h_offset, w_offset, h_end, w_end)
        return x[:, :, h_offset:h_end, w_offset:w_end]

model = nn.Sequential(ToFloat(), CenterCrop(224, use_jit=True))
onnxfile = "/mnt/output/gr/crop.onnx"

targets = ["cropped"]
dynamic_axes = {'data': [2, 3]}
dummy_input = torch.randn(1, 3, 300, 256, device='cpu').byte()
torch.onnx.export(model, dummy_input, onnxfile,
                  verbose=True, input_names=['data'],
                  dynamic_axes=dynamic_axes,
                  output_names=targets,
                  opset_version=10)

Now run the optimizer:

import onnx
from onnx import optimizer
onnx_model = onnx.load(onnxfile)
passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
optimized_model = optimizer.optimize(onnx_model, passes)
onnx.save(optimized_model, onnxfile)

And you get this cryptic error message:

~/miniconda3/envs/py3/lib/python3.6/site-packages/onnx/optimizer.py in optimize(model, passes, fixed_point)
     53         optimized_model_str = C.optimize_fixedpoint(model_str, passes)
     54     else:
---> 55         optimized_model_str = C.optimize(model_str, passes)
     56
     57     return onnx.load_from_string(optimized_model_str)

IndexError: _Map_base::at

And the checker segfaults!

onnx.checker.check_model(onnx_model)
Segmentation fault

Related to issue #1385, but with a repro.


dashesy commented Oct 24, 2019

If I export using keep_initializers_as_inputs=True as suggested in #1385, the optimizer no longer has an issue, but the checker still segfaults.
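
For reference, a sketch of that export (reusing model, dummy_input, onnxfile, targets and dynamic_axes from the repro above; only the last flag is new):

# Hedged sketch: same export call as in the repro, with keep_initializers_as_inputs=True
# so PyTorch lists the weights in graph.input again (the pre-1.3 layout the old
# optimizer code path expects).
torch.onnx.export(model, dummy_input, onnxfile,
                  verbose=True, input_names=['data'],
                  dynamic_axes=dynamic_axes,
                  output_names=targets,
                  opset_version=10,
                  keep_initializers_as_inputs=True)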


daquexian commented Oct 24, 2019

+1. It is because the current code (https://github.com/onnx/onnx/blob/master/onnx/common/ir_pb_converter.cc#L230-L237) only adds the graph proto's inputs, but not its initializers, to the map value_by_name_of.

I think it is a very critical bug :| It affects all models exported from PyTorch 1.3.
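
A quick way to check whether a given model is in the affected layout (a sketch, reusing onnxfile from the repro above):

import onnx

m = onnx.load(onnxfile)
graph_input_names = {i.name for i in m.graph.input}
initializer_names = {init.name for init in m.graph.initializer}
# For PyTorch >= 1.3 exports this set is non-empty: initializers are no longer
# duplicated in graph.input, which is exactly what value_by_name_of never records.
print(initializer_names - graph_input_names)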

@corleonechensiyu

+1. It is because the current code (https://github.com/onnx/onnx/blob/master/onnx/common/ir_pb_converter.cc#L230-L237) only adds the graph proto's inputs, but not its initializers, to the map value_by_name_of.

I think it is a very critical bug :| It affects all models exported from PyTorch 1.3.

It's not just a PyTorch version issue. I converted a model with keras2onnx and then used onnxsim to optimize away Identity nodes, and hit the same problem...
optimized_model_str = C.optimize(model_str, passes)
IndexError: _Map_base::at
Versions: keras 2.3.0, onnx 1.6.0, tf-gpu 1.14.0, keras2onnx source code 446b606

@daquexian

It is not so easy to fix. #2247 contains a failed attempt.

@prasanthpul @linkerzhang Is there some plan to fix it from ONNX maintainers?

@corleonechensiyu

It is not so easy to fix. #2247 contains a failed attempt.

@prasanthpul @linkerzhang Is there some plan to fix it from ONNX maintainers?

Would a lower keras version work?


daquexian commented Oct 28, 2019

It is not so easy to fix. #2247 contains a failed attempt.
@prasanthpul @linkerzhang Is there some plan to fix it from ONNX maintainers?

Would a lower keras version work?

Maybe :) You can give it a try.

@skottmckay

If I export using keep_initializers_as_inputs=True as suggested in #1385, the optimizer no longer has an issue, but the checker still segfaults.

@dashesy If I run checker on the original exported model (without adding keep_initializers_as_inputs=True) it's happy. Is it only after you attempt to run the optimizers that the checker is unhappy with the model?

FWIW onnxruntime has working implementations of the optimizations you're attempting, and the latest version can save an updated model post-optimization. Set SessionOptions.optimized_model_filepath before loading the model and it will write the optimized onnx model out to that path. It will check that an initializer is actually const before fusing, and can also search parent graphs for initializers in order to optimize subgraphs in control flow nodes.

It will take a fairly significant overhaul of the IR and optimizer setup in onnx to make the implementations there more correct, so ORT is probably your best short-term option.

e.g.

import onnxruntime as ort
so = ort.SessionOptions()
so.optimized_model_filepath = 'optimized_crop.onnx'
session = ort.InferenceSession('crop.onnx', so)

import onnx
m = onnx.load('optimized_crop.onnx')
onnx.checker.check_model(m)


daquexian commented Oct 31, 2019

onnx-simplifier now also supports the case where initializers are not listed in the inputs :) It generates optimized and clean ONNX models.
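
For example, a minimal usage sketch (file names are illustrative; check reports whether the simplified model still matches the original on random inputs):

import onnx
from onnxsim import simplify

model = onnx.load('crop.onnx')
model_simp, check = simplify(model)
assert check, 'simplified model could not be validated'
onnx.save(model_simp, 'crop_simplified.onnx')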


dashesy commented Nov 1, 2019

@daquexian I actually rely on dynamic_axes a lot, so I do not want to remove the shape nodes unless they refer to static axes in data. Otherwise I would like to give onnx-simplifier a go!


dashesy commented Nov 4, 2019

@skottmckay I tried

sess_options = rt.SessionOptions()
#sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
sess_options.optimized_model_filepath = "/mnt/output/gr/model_optimzied.onnx"
sess = rt.InferenceSession(onnxfile, sess_options)

The model is twice the size, but I still get these warnings (which is what I used onnx.optimizer for):

2019-11-04 15:52:51.3438718 [W:onnxruntime:CSharpOnnxRuntime, graph.cc:2367 onnxruntime::Graph::CleanUnusedInitializers] Removing initializer '2.tee.4.layer4_4.2.conv1.weight'. It is not used by any node and should be removed from the model

I used to be able to fix these using the onnx optimizer, but now that is broken.
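
As a stop-gap, a rough sketch of stripping just the unused initializers with the plain onnx Python API (top-level graph only, subgraphs ignored; reuses onnxfile from above):

import onnx

model = onnx.load(onnxfile)
graph = model.graph
used_names = {name for node in graph.node for name in node.input}
unused = [init for init in graph.initializer if init.name not in used_names]
for init in unused:
    # Drop initializers no node refers to, e.g. '2.tee.4.layer4_4.2.conv1.weight' above.
    graph.initializer.remove(init)
onnx.save(model, onnxfile)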

@skottmckay

@dashesy There is a gap in how ORT is handling initializers that become redundant during optimizations. microsoft/onnxruntime#2320 should address that.


dashesy commented Nov 7, 2019

@skottmckay I applied that PR as a patch against current master, and it works! I used ORT_ENABLE_EXTENDED; it got rid of the unused initializers and also fused Conv+Relu as well as Conv+BN.

[Before/after graph screenshots: the unused initializers are gone and Conv+BN / Conv+Relu are fused.]

Which is all I wanted to do.
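
For completeness, the working configuration as a sketch (paths illustrative; assumes the patched onnxruntime build mentioned above):

import onnxruntime as rt

so = rt.SessionOptions()
so.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
so.optimized_model_filepath = '/mnt/output/gr/model_optimized.onnx'
# Loading the session runs the optimizers and writes the optimized model to disk.
sess = rt.InferenceSession(onnxfile, so)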


skottmckay commented Nov 7, 2019

I used ORT_ENABLE_EXTENDED; it got rid of the unused initializers and also fused Conv+Relu as well as Conv+BN.

Excellent. One note in case it's relevant - FusedConv is not an official ONNX operator, so this saved model can only be run by ORT. If you need a model that conforms to the ONNX spec, a lower optimization level would need to be used. For this model that would still mean Conv + BN gets fused and unused initializers removed, but the (Conv + BN result) + Relu fusion that is handled by FusedConv would be missing.
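
A sketch of that lower level (assuming, as described above, that basic-level rewrites only use standard ONNX ops, so the saved model stays spec-conformant):

import onnxruntime as rt

so = rt.SessionOptions()
so.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_BASIC
so.optimized_model_filepath = 'optimized_portable.onnx'
sess = rt.InferenceSession('crop.onnx', so)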


daquexian commented Nov 16, 2019

If I export using keep_initializers_as_inputs=True as suggested in #1385, the optimizer no longer has an issue, but the checker still segfaults.

@dashesy If I run checker on the original exported model (without adding keep_initializers_as_inputs=True) it's happy. Is it only after you attempt to run the optimizers that the checker is unhappy with the model?

FWIW onnxruntime has working implementations of the optimizations you're attempting and in the latest version can save an updated model post optimization. Set SessionOptions.optimized_model_filepath before loading the model and it will write the optimized onnx model out to that path. It will check that an initializer is actually const before fusing, and can also search parent graphs for initializers in order to optimize subgraphs in control flow nodes.

It will take a fairly significant overhaul of the IR and optimizer setup in onnx to make the implementations there more correct, so ORT is probably your best short-term option.

e.g.

import onnxruntime as ort
so = ort.SessionOptions()
so.optimized_model_filepath = 'optimized_crop.onnx'
session = ort.InferenceSession('crop.onnx', so)

import onnx
m = onnx.load('optimized_crop.onnx')
onnx.checker.check_model(m)

This issue deserves the highest priority, as it affects many, many models, and we cannot ask every user or even every library (e.g., the caffe2 onnx backend, #2458) to use onnxruntime just for optimization... :( What's your opinion on it?

@linkerzhang

Thanks a lot, @daquexian, for investigating the issue and getting to the root cause!

Sorry that I didn't chime into this topic early.

Yes, this is indeed an issue with the ONNX IR (C++) and the optimizer in the ONNX repo. As the ONNX spec (model format and op spec) keeps moving forward, the ONNX IR (C++) and the optimizers have not been maintained properly; that is to say, they are not currently treated as part of the ONNX standard. So there are two options for us (the full community):

  1. Make the IR (C++) and the optimizer part of the standard repo and maintain them properly.
  2. Move the IR (C++) and the optimizers out of the standard repo, to be maintained and contributed to by the community.

I called out this topic at the workshop in Shanghai and would prefer the 2nd option above.

Meanwhile, I do think ONNX Runtime maintains a fair list of optimizers rather better, so you may choose that. The ONNX Runtime team will make the optimizer lib more general and easier to use.


jcwchen commented Apr 15, 2021

Please note that the ONNX optimizer has been moved to a separate repo, https://github.com/onnx/optimizer, since ONNX 1.9. If you still have questions related to the optimizer, please raise an issue there. Thank you!
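
A hedged sketch of the equivalent call with the standalone package (pip install onnxoptimizer), mirroring the passes from the original repro; file names are illustrative:

import onnx
import onnxoptimizer

model = onnx.load('crop.onnx')
passes = ['extract_constant_to_initializer', 'eliminate_unused_initializer']
optimized = onnxoptimizer.optimize(model, passes)
onnx.save(optimized, 'crop_optimized.onnx')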
