Skip to content

Conversation

@kakukakujirori
Copy link

When I convert EdgeNeXt to ONNX, the converted model throws an error at inference time. This occurs when I use the 3rd or 4th features of EdgeNeXt.

I tested the following in linux and mac environment with various python & opset version, and they all returned the same error messages. So this problem seems not environment-dependent, but specific to pytorch->onnx conversion.

import numpy as np
import torch
import torch.nn as nn
import timm
import onnxruntime

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = timm.create_model("edgenext_small", pretrained=True, features_only=True)
    
    def forward(self, x):
        return self.net(x)[3]

# Input to the model
batch_size = 2
x = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
torch_model = MyModel()
torch_out = torch_model(x)

# Export the model
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "edgenext.onnx",           # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_session = onnxruntime.InferenceSession("edgenext.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
for out_t, out_o in zip(torch_out, ort_outs):
    np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

Error message:

/Users/kakujiro/opt/anaconda3/envs/py310/lib/python3.10/site-packages/timm/models/edgenext.py:157: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
............(omitted)............
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
2022-07-30 20:14:01.625713 [E:onnxruntime:, sequential_executor.cc:368 Execute] Non-zero status code returned while running Split node. Name:'Split_627' Status Message: Cannot split using values in 'split' attribute. Axis=1 Input shape={2,160,14,14} NumOutputs=3 Num entries in 'split' (must equal number of outputs) was 3 Sum of sizes in 'split' (must equal size of selected axis) was 162
Traceback (most recent call last):
  File "/Users/kakujiro/Desktop/edgenext_check.py", line 39, in <module>
    ort_outs = ort_session.run(None, ort_inputs)
  File "/Users/kakujiro/opt/anaconda3/envs/py310/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Split node. Name:'Split_627' Status Message: Cannot split using values in 'split' attribute. Axis=1 Input shape={2,160,14,14} NumOutputs=3 Num entries in 'split' (must equal number of outputs) was 3 Sum of sizes in 'split' (must equal size of selected axis) was 162

After some probe I found that this error comes from F.normalize in CrossCovarianceAttn, especially within its clip operation by eps, and it was resolved when I wrote a naive normalize function without using torch.nn.functional.

I also modified the line qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) according to the above warning.

I have verified that

  • all the above warnings and errors now go away,
  • its onnx inference works correctly and its output is almost the same as that of the original pytorch model (checked by np.testing.assert_allclose with rtol=1e-03, atol=1e-05 according to this),
  • this change doesn't affect loading the pretrained weights since learnable parameters are untouched,
  • and model outputs doesn't change before and after this modification for all four edgenext_small variants.

The last correspondence was checked using the following (dirty) script. By the way my testing environment is python=3.10.5, torch=1.12.0, onnxruntime=1.12.0.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import onnxruntime

# Implement from parent directory in order to avoid collision of local timm import and that installed from pypl. Here my folked timm repository `pytorch_image_models_forked` exists
from pytorch_image_models_forked.timm.models.edgenext import EdgeNeXt, edgenext_small, edgenext_small_rw, edgenext_x_small, edgenext_xx_small

ori_net = timm.create_model("edgenext_small", pretrained=True, features_only=True)
net = edgenext_small(pretrained=True, features_only=True)
ori_net.eval()
net.eval()

x = torch.randn(2, 3, 224, 224, requires_grad=True)
with torch.no_grad():
    ori_ret = ori_net(x)
    ret = net(x)

for o, r in zip(ori_ret, ret):
    assert torch.allclose(o, r)

I hope this will help others who want to deploy this model in other formats.

@rwightman
Copy link
Collaborator

@kakukakujirori thanks for the PR, I looked at this and didn't like using a custom normalize, it seems it's some sort of interaction between the split and possibly the normalize but I think split is the bigger issue, it works for me if I change the split to a chunk so I have pushed that up now... and did the little shape change to remove one of the warnings (although that one never seems to impact the result)

@kakukakujirori
Copy link
Author

My bad, your correction looks much more smart and gets to the point.
I tested it on my laptop and verified it works fine, so I'll close this PR.
Thank you so much for your review.

@kakukakujirori kakukakujirori deleted the edgenext_onnx_exportable branch August 6, 2022 04:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants