
Support onnx #12

Closed
phamkhactu opened this issue Aug 2, 2022 · 61 comments
Labels
documentation Improvements or additions to documentation

Comments

@phamkhactu

Can the model be converted to an ONNX model?

@baudm
Owner

baudm commented Aug 2, 2022

Tried it just now. Was able to export to ONNX using torch.onnx.export(parseq, dummy_input, 'parseq.onnx', opset_version=14). Not really familiar yet with ONNX so I can't verify if the exported model works as expected (an exported TorchScript model works though, if that matters).

@phamkhactu
Author

phamkhactu commented Aug 3, 2022


@baudm I cannot convert it to ONNX. The main problem comes from the load_from_checkpoint function: the model must be loaded from its architecture before it can be converted to ONNX. This is the function that loads the model from hubconf:

import os

import torch


def _load_torch_model(checkpoint_path, checkpoint, **kwargs):
    import hubconf  # the repo's hubconf.py must be importable (e.g. run from the repo root)
    name = os.path.basename(checkpoint_path).split('-')[0]  # model name = filename prefix before the first '-'
    model_factory = getattr(hubconf, name)
    model = model_factory(**kwargs)
    model.load_state_dict(checkpoint)
    return model


def load_from_checkpoint(checkpoint_path: str, **kwargs):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    try:
        # _load_pl_checkpoint is defined elsewhere in the same module (not shown here)
        model = _load_pl_checkpoint(checkpoint, **kwargs)
    except KeyError:
        model = _load_torch_model(checkpoint_path, checkpoint, **kwargs)
    return model

Can you share a code example for loading a checkpoint into the model architecture?
If I can convert it to ONNX, I will publish it so others can test whether it works as expected.

@baudm
Owner

baudm commented Aug 3, 2022

import torch

parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
dummy_input = torch.rand(1, 3, *parseq.hparams.img_size)  # (1, 3, 32, 128) by default

# To ONNX
parseq.to_onnx('parseq.onnx', dummy_input, opset_version=14)  # opset v14 or newer is required

# To TorchScript
parseq.to_torchscript('parseq-ts.pt')

baudm added the documentation label Aug 8, 2022
baudm closed this as completed Aug 8, 2022
@phamkhactu
Author

@baudm The model converted to ONNX successfully, but I cannot load the ONNX model. I am asking PyTorch experts for help; once it is resolved, I will share the final ONNX model.

@phamkhactu
Author

phamkhactu commented Aug 16, 2022

@baudm After several days of trying, I still cannot fix the ONNX model. I would be very happy if you could give a few lines of example code for running inference with the TorchScript model you converted:

# To TorchScript
parseq.to_torchscript('parseq-ts.pt')

I get this error:

model = torch.jit.load("parseq-ts.pt")
  File "anaconda3/envs/dl/lib/python3.6/site-packages/torch/jit/_serialization.py", line 161, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: 
Unknown builtin op: aten::reflection_pad3d.
Here are some suggestions: 
        aten::reflection_pad1d
        aten::reflection_pad2d

The original call is:
  File "/home/tupk/anaconda3/envs/ocr/lib/python3.8/site-packages/torch/nn/functional.py", line 4199
        elif len(pad) == 6 and (input.dim() == 4 or input.dim() == 5):
            if mode == "reflect":
                return torch._C._nn.reflection_pad3d(input, pad)
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            elif mode == "replicate":
                return torch._C._nn.replication_pad3d(input, pad)
Serialized   File "code/__torch__/torch/nn/functional.py", line 634
        if _175:
          if torch.eq(mode, "reflect"):
            _180 = torch.reflection_pad3d(input, pad)
                   ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            _179 = _180
          else:
'_pad' is being compiled since it was called from 'multi_head_attention_forward'
  File "/home/tupk/anaconda3/envs/ocr/lib/python3.8/site-packages/torch/nn/functional.py", line 5032
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
                        ~~~ <--- HERE
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
Serialized   File "code/__torch__/torch/nn/functional.py", line 235
    if torch.__isnot__(attn_mask0, None):
      attn_mask6 = unchecked_cast(Tensor, attn_mask0)
      attn_mask7 = __torch__.torch.nn.functional._pad(attn_mask6, [0, 1], "constant", 0., )
      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      attn_mask5 : Optional[Tensor] = attn_mask7
    else:
'multi_head_attention_forward' is being compiled since it was called from 'MultiheadAttention.forward'
Serialized   File "code/__torch__/torch/nn/modules/activation.py", line 39
    need_weights: bool=True,
    attn_mask: Optional[Tensor]=None) -> Tuple[Tensor, Optional[Tensor]]:
    _1 = __torch__.torch.nn.functional.multi_head_attention_forward
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _2 = annotate(List[Tensor], [])
    _3 = torch.append(_2, torch.transpose(query, 1, 0))
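For the TorchScript inference question above, here is a minimal sketch. It assumes the file was produced by `parseq.to_torchscript('parseq-ts.pt')`, that the image is already an RGB uint8 array resized to the model's `img_size` (32x128 by default), and that preprocessing mirrors PARSeq's test transform (`ToTensor` followed by `Normalize(0.5, 0.5)`); check that against your own data module. `preprocess` and `run_torchscript` are hypothetical helper names. The traceback above shows a python3.6 environment loading a module compiled from a python3.8 environment, which suggests a PyTorch version mismatch, so loading with the same PyTorch version that exported the file is probably the first thing to check.

```python
import numpy as np


def preprocess(image_hwc: np.ndarray) -> np.ndarray:
    """uint8 RGB image already resized to (H, W) = (32, 128) -> float32 (1, 3, H, W)."""
    x = image_hwc.astype(np.float32) / 255.0  # like ToTensor(): scale to [0, 1]
    x = (x - 0.5) / 0.5  # assumed Normalize(0.5, 0.5): map to [-1, 1]
    return np.ascontiguousarray(x.transpose(2, 0, 1)[None])  # HWC -> NCHW with N=1


def run_torchscript(path: str, image_hwc: np.ndarray):
    """Load a TorchScript export and return per-step character probabilities."""
    import torch  # lazy import; use the same PyTorch version that exported the file

    model = torch.jit.load(path, map_location="cpu").eval()
    with torch.no_grad():
        logits = model(torch.from_numpy(preprocess(image_hwc)))
    return logits.softmax(-1)  # shape (1, steps, num_classes)
```

The returned probabilities can then be decoded greedily (argmax per step, cut at `[E]`).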

@cywinski

@baudm I have a similar issue with loading the converted ONNX model. I can convert the model to ONNX successfully, but when I load it and check whether the model is well-formed, I get the following error.

import torch
import onnx

# Load PyTorch model
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
dummy_input = torch.rand(1, 3, *parseq.hparams.img_size)

# Convert to ONNX
parseq.to_onnx('pairseq.onnx', dummy_input, opset_version=14)

# Load the ONNX model
onnx_model = onnx.load('pairseq.onnx')

# Check ONNX model
onnx.checker.check_model(onnx_model, full_check=True)


---------------------------------------------------------------------------
InferenceError                            Traceback (most recent call last)
Input In [1], in <cell line: 15>()
     12 onnx_model = onnx.load('pairseq.onnx')
     14 # Check ONNX model
---> 15 onnx.checker.check_model(onnx_model, full_check=True)

File /opt/venv/lib/python3.8/site-packages/onnx/checker.py:108, in check_model(model, full_check)
    106 C.check_model(protobuf_string)
    107 if full_check:
--> 108     onnx.shape_inference.infer_shapes(model, check_type=True, strict_mode=True)

File /opt/venv/lib/python3.8/site-packages/onnx/shape_inference.py:34, in infer_shapes(model, check_type, strict_mode, data_prop)
     32 if isinstance(model, (ModelProto, bytes)):
     33     model_str = model if isinstance(model, bytes) else model.SerializeToString()
---> 34     inferred_model_str = C.infer_shapes(model_str, check_type, strict_mode, data_prop)
     35     return onnx.load_from_string(inferred_model_str)
     36 elif isinstance(model, str):

InferenceError: [ShapeInferenceError] (op_type:CumSum, node name: CumSum_2527): x typestr: T, has unsupported type: tensor(bool)

baudm reopened this Aug 16, 2022
@WongVi

WongVi commented Aug 18, 2022

Waiting for ONNX and TensorRT conversion.

@mcmingchang

ONNX export succeeds with this change:

tgt_padding_mask = (((tgt_in == self.eos_id)*2).cumsum(-1) > 0) # mask tokens beyond the first EOS token.
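Why the `*2` helps: ONNX `CumSum` does not accept `bool` inputs (that is exactly the `ShapeInferenceError` reported above), and multiplying the boolean comparison by an integer turns it into an integer tensor first. The NumPy sketch below reproduces the mask logic with an explicit integer cast; it assumes `[E]` has index 0, as in the tokenizer layout shown later in this thread, and `mask_from_first_eos` is a hypothetical name. In PyTorch, `.long()` before `.cumsum(-1)` should have the same effect as `*2`.

```python
import numpy as np

EOS_ID = 0  # assumed: index of [E], as in the tokenizer layout shown later in this thread


def mask_from_first_eos(tgt_in: np.ndarray) -> np.ndarray:
    """True at the first EOS in each row and at every position after it."""
    is_eos = tgt_in == EOS_ID  # boolean tensor: the dtype ONNX CumSum rejects
    counts = np.cumsum(is_eos.astype(np.int64), axis=-1)  # integer running count of EOS
    return counts > 0
```

Any positive count means an EOS has already been seen in that row, so the comparison yields the same mask whether the cast is done with `*2`, `.long()`, or `astype`.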

@RickyGunawan09

RickyGunawan09 commented Sep 9, 2022

@mcmingchang Can you elaborate? I can convert the model to ONNX, but I can't use it with ONNX Runtime. When building the ONNX model, I get the following messages:

C:\Users\1000\.conda\envs\parseg\lib\site-packages\torch\__init__.py:833: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert condition, message

C:\Users\1000\.conda\envs\parseg\lib\site-packages\timm\models\vision_transformer.py:201: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of PyTorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

D:\parseg\parseq\strhub\models\parseq\system.py:129: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if testing and (tgt_in == self.eos_id).any(dim=-1).all():

Has anyone already built and run the ONNX model successfully? Could someone share it?
Thanks

@ashishpapanai

I am having similar issues with ONNX. Any leads?

@phamkhactu
Author

@ashishpapanai Maybe we have to wait for an expert to solve it.

@baudm
Owner

baudm commented Oct 4, 2022

tgt_padding_mask = ((tgt_in == self.eos_id).cumsum(-1) > 0) # mask tokens beyond the first EOS token.

This is the offending code fragment. You can comment this out (disabling the iterative refinement branch of the code) before exporting the ONNX model. I tried it and onnx.checker.check_model(parseq, full_check=True) succeeded.

@RickyGunawan09


Thank you for answering @baudm
I'll try it and I'll report back when everything works fine or if there are other issues.

@allenwu5

allenwu5 commented Oct 5, 2022


Thank you @baudm

I tried the code below with parseq.refine_iters=0, and the onnx::CumSum_3090-related errors are gone.

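import torch
from onnxruntime.quantization import QuantType, quantize_dynamic

# Assumed setup: parseq_tiny from torch.hub and a random dummy input
parseq = torch.hub.load('baudm/parseq', 'parseq_tiny', pretrained=True).eval()
img = torch.rand(1, 3, *parseq.hparams.img_size)

fp32_onnx_path = "parseq_tiny_fp32.onnx"
parseq.refine_iters = 0  # skip iterative refinement (avoids the bool CumSum node)
parseq.to_onnx(fp32_onnx_path, img, opset_version=14)

int8_onnx_path = "parseq_tiny_uint8.onnx"
quantize_dynamic(
    model_input=fp32_onnx_path,
    model_output=int8_onnx_path,
    weight_type=QuantType.QUInt8,  # dynamic quantization: weights stored as uint8
)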

@baudm
Owner

baudm commented Oct 5, 2022


@allenwu5 oh yeah, this is even better. Setting refine_iters to 0 will make the iterative refinement branch unreachable, achieving the same effect. Will close this now and update the documentation.

baudm closed this as completed Oct 5, 2022
@baudm
Owner

baudm commented Oct 5, 2022

UPDATE: As of commit ed3d847, refine_iters=0 is no longer required when exporting to ONNX.

Originally, the summary was to set refine_iters=0 when exporting to ONNX; with the update above, this is sufficient:

import torch

parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
dummy_input = torch.rand(1, 3, *parseq.hparams.img_size)  # (1, 3, 32, 128) by default

# To ONNX
parseq.to_onnx('parseq.onnx', dummy_input, opset_version=14)  # opset v14 or newer is required

@huyhoang17

huyhoang17 commented Oct 6, 2022

@baudm
Hi, I tried the above solution and converted the model to ONNX successfully.
But there is a problem: the output size of the ONNX model changed. For example, max_label_length of the base model is 25, but the output size of the ONNX model was only 7.

[Screenshot: parseq_issue_1 (not preserved)]

The output size changed between different conversions.
How can I fix this?
Thanks in advance

@ashishpapanai

I am facing a similar issue with the output shape.
cc: @baudm

@baudm
Owner

baudm commented Oct 6, 2022

max_label_length is exactly that: a maximum. Autoregressive decoding terminates once the [E] (EOS) token is generated, so the output sequence can be shorter than the maximum supported label length. If you want a constant sequence length, use NAR decoding (decode_ar=False).
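A toy illustration of the difference, in plain Python: the autoregressive loop stops as soon as `[E]` is produced, while a non-autoregressive pass always emits `max_label_length + 1` positions (26 for the default 25, matching the 26x95 output mentioned later in this thread). This is also a plausible reason the exported sequence length depends on the dummy input: the `TracerWarning` on `system.py:129` quoted above shows that the early-exit check becomes a Python boolean during tracing, so the number of loop iterations taken for the dummy input is what gets recorded. The scripted `step` function and `EOS = 0` are assumptions for illustration only.

```python
from typing import Callable, List

EOS = 0  # assumed index of [E]


def decode_ar(step: Callable[[List[int]], int], max_label_length: int) -> List[int]:
    """Toy autoregressive loop: one token per step, stop right after [E]."""
    out: List[int] = []
    for _ in range(max_label_length + 1):  # +1 position reserved for [E]
        token = step(out)
        out.append(token)
        if token == EOS:
            break  # early exit: output length depends on the input
    return out


def decode_nar(predict_all: Callable[[int], List[int]], max_label_length: int) -> List[int]:
    """Toy non-autoregressive pass: every position predicted at once, fixed length."""
    return predict_all(max_label_length + 1)
```

For a three-character word, `decode_ar` returns four tokens (three characters plus `[E]`), while `decode_nar` always returns 26 for `max_label_length=25`.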

@huyhoang17

huyhoang17 commented Oct 6, 2022

@baudm Thank you! It works.

I was also able to convert the ONNX model to TensorRT format and achieve the same result.
For anyone who wants to convert to TensorRT format: simplify the ONNX model with onnx-simplifier, then build the TRT model with the trtexec tool.

Inference-time benchmark for PyTorch, ONNX Runtime and TensorRT (input 3x32x128, batch size 1, averaged over 100 samples):

Backend          Avg. time (s)   Relative to tensorrt-fp32
torch            0.017518        4.1839x
onnx-runtime     0.015875        3.7915x
tensorrt-fp32    0.004187        1x
tensorrt-fp16    0.002519        -

The TensorRT FP32 model is about 4 times faster than the PyTorch model. The TensorRT model was served with Triton Inference Server.
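The exact timing script was not shared. A minimal, framework-agnostic harness for this kind of measurement (warm-up calls, then the mean over 100 runs, as in the table) might look like the sketch below; `mean_latency` and `relative_to` are hypothetical helpers. For GPU backends, the callable must block until results are ready (for example by copying outputs back to the host), otherwise asynchronous execution makes the numbers look too small.

```python
import statistics
import time
from typing import Any, Callable, Dict


def mean_latency(fn: Callable[[Any], Any], x: Any, runs: int = 100, warmup: int = 10) -> float:
    """Mean wall-clock seconds per call of fn(x), measured after a few warm-up calls."""
    for _ in range(warmup):
        fn(x)  # warm-up: lazy initialization, caches, kernel autotuning
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(x)
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings)


def relative_to(baseline: float, latencies: Dict[str, float]) -> Dict[str, float]:
    """Express each latency as a multiple of the baseline (1.0 = same as baseline)."""
    return {name: t / baseline for name, t in latencies.items()}
```

With the numbers from the table, `relative_to(0.004187, {"torch": 0.017518})` gives about 4.184 for torch, consistent with the 4.1839x reported.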

@dietermaes

@baudm Thanks for the advice, the export to onnx worked now.

@huyhoang17 I'm also running the model on a Triton server. I can make an inference request, and converting the result with the Triton client's as_numpy function gives me an array of shape [1, 7, 95]. Do you have any advice on how to extract the label and confidence scores from this array?
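One way to turn such an array into a label and a confidence score, sketched in NumPy under these assumptions: the 95 classes are `[E]` at index 0 followed by the default 94-character charset (the same layout as the `TokenDecoder` posted later in this thread), decoding is greedy, and the confidence is the product of the per-step maximum probabilities up to and including `[E]`. Check the charset against your model's `hparams` before relying on it; `decode_logits` is a hypothetical name.

```python
import string

import numpy as np

# Assumed layout: index 0 is [E], indices 1..94 are the default 94-character charset.
CHARSET = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation
EOS_ID = 0


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    z = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def decode_logits(logits: np.ndarray):
    """(batch, steps, 95) logits -> list of (label, confidence), greedy decoding."""
    probs = softmax(logits.astype(np.float64))
    ids = probs.argmax(-1)
    best = probs.max(-1)
    results = []
    for seq_ids, seq_probs in zip(ids, best):
        eos_positions = np.flatnonzero(seq_ids == EOS_ID)
        end = int(eos_positions[0]) if eos_positions.size else len(seq_ids)
        label = "".join(CHARSET[i - 1] for i in seq_ids[:end])  # shift past [E]
        confidence = float(np.prod(seq_probs[:end + 1]))  # include the [E] step
        results.append((label, confidence))
    return results
```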

@huyhoang17

huyhoang17 commented Oct 6, 2022

@ashishpapanai

@huyhoang17 How did you make the output dimensions equal to [1, max_label_length, 95]?
@baudm I tried setting decode_ar=False, but the output dimensions are now [1, 6, 95]. It would help if I could make the output shape [1, 25, 95] and then print the recognized characters in post-processing.

@huyhoang17

@ashishpapanai Here is the example code. You should set both parameters: decode_ar=False and refine_iters=0.

Lib version

torch==1.12.1
import onnx
from strhub.models.utils import load_from_checkpoint

# To ONNX
device = "cuda"
ckpt_path = "..."
onnx_path = "..."
img = ...  # example input tensor on `device`, e.g. shape (1, 3, 32, 128)

parseq = load_from_checkpoint(ckpt_path)
parseq.refine_iters = 0  # disable iterative refinement
parseq.decode_ar = False  # NAR decoding: fixed output length
parseq = parseq.to(device).eval()

parseq.to_onnx(onnx_path, img, do_constant_folding=True, opset_version=14)  # opset v14 or newer is required

# Check
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model, full_check=True)  # passes

@jturner116

@huyhoang17 I would love to see your code for ONNX inference, I am very interested in and impressed by your speedtesting!

@ashishpapanai

With AR decoding enabled, compiling the OpenVINO IR model takes 81 minutes, which is far too long. Is there any community knowledge I can use to optimize the model?

@phamkhactu
Author

@phamkhactu Shoot, that performance degradation at the end is exactly what I noticed too. I thought the conversion to ONNX was supposed to be a bit more stable, I wanted to blame it on my PyTorch version haha. I'm on Torch 1.10.2, were you on 1.12?

Yes, I had to reinstall torch 1.12. With torch 1.10 the ONNX conversion succeeds, but inference crashes with a core dump. =((

@cywinski

Has anyone tried converting any of the available models to Core ML?
I was able to convert the parseq_tiny model to TorchScript, and then to Core ML, with the following code:

import torch
import coremltools as ct

parseq = torch.hub.load('baudm/parseq', 'parseq_tiny', pretrained=True, refine_iters=0).eval()
dummy_input = torch.rand(1, 3, *parseq.hparams.img_size)  # (1, 3, 32, 128) by default

# Convert to TorchScript
torch_jit = parseq.to_torchscript(file_path='torchscript_model.pt', method='trace', example_inputs=dummy_input)

# Convert to CoreML
coreml_model = ct.convert(
        model=torch_jit,
        inputs=[
            ct.TensorType(name="input", shape=(1, 3, ct.RangeDim(), ct.RangeDim()))
        ],
        outputs=[
            ct.TensorType(name="output")
        ],
        convert_to="neuralnetwork",
)

coreml_model.save('parseq_tiny.mlmodel')

But when I try to load the model in Xcode 14, I get the following error:
[Screenshot of the Xcode error (not preserved)]

@DenghuiXiao

@huyhoang17 I followed your steps and ran inference on the ONNX model successfully, but I can't convert the ONNX file into a TRT engine with trtexec (./trtexec --onnx=/parseq/parseq_sim.onnx --saveEngine=/parseq/parseq.enigin). Error: Could not open file ~/parseq/parseq_sim.onnx
TensorRT version: 7.2.3.4
torch 1.12.0
ONNX opset 14

@tp-nan

tp-nan commented Dec 6, 2022

Is it possible to train and convert a model with dynamic shapes (an input image of shape (-1, c, -1, -1))?
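A sketch of exporting with a dynamic batch dimension, assuming the model exposes PyTorch Lightning's `to_onnx`, which forwards extra keyword arguments to `torch.onnx.export` (where `input_names`, `output_names` and `dynamic_axes` are standard parameters). Height and width are kept fixed here because the encoder is built for `hparams.img_size`; making them dynamic as well would likely need more than an export flag. The tensor names `input`/`logits` and the helper name `export_dynamic_batch` are illustrative, and the decoding caveats discussed above (`decode_ar`, `refine_iters`) still apply.

```python
# Illustrative tensor names; only the batch axis (dim 0) is marked dynamic.
DYNAMIC_AXES = {
    "input": {0: "batch"},   # (batch, 3, H, W)
    "logits": {0: "batch"},  # (batch, steps, num_classes)
}


def export_dynamic_batch(model, onnx_path: str, img_size=(32, 128)) -> None:
    """Export with a dynamic batch size via LightningModule.to_onnx."""
    import torch  # lazy import so the axis spec can be inspected without torch

    dummy = torch.rand(1, 3, *img_size)
    model.to_onnx(
        onnx_path,
        dummy,
        opset_version=14,  # opset v14 or newer, as noted above
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes=DYNAMIC_AXES,
    )
```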

@rafaelagrc

rafaelagrc commented Dec 13, 2022

In summary, set refine_iters=0 when exporting to ONNX:

import torch

parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, refine_iters=0).eval()
dummy_input = torch.rand(1, 3, *parseq.hparams.img_size)  # (1, 3, 32, 128) by default

# To ONNX
parseq.to_onnx('parseq.onnx', dummy_input, opset_version=14)  # opset v14 or newer is required

Hello.
I ran this exact code fragment, with refine_iters=0, and got the following warnings:

/home/rafaela.carvalho/anaconda3/envs/parseq/lib/python3.9/site-packages/torch/__init__.py:676: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert condition, message

/home/rafaela.carvalho/anaconda3/envs/parseq/lib/python3.9/site-packages/timm/models/vision_transformer.py:217: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

/home/rafaela.carvalho/.cache/torch/hub/baudm_parseq_main/strhub/models/parseq/system.py:129: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if testing and (tgt_in == self.eos_id).any(dim=-1).all():

/home/rafaela.carvalho/anaconda3/envs/parseq/lib/python3.9/site-packages/torch/onnx/symbolic_helper.py:716: UserWarning: allowzero=0 by default. In order to honor zero value in shape use allowzero=1 warnings.warn("allowzero=0 by default. In order to honor zero value in shape use allowzero=1")

/home/rafaela.carvalho/anaconda3/envs/parseq/lib/python3.9/site-packages/torch/onnx/symbolic_helper.py:325: UserWarning: Type cannot be inferred, which might cause exported graph to produce incorrect results. warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.")

I am using torch==1.10.2 and torchvision==0.11.3

Has anyone run into this issue, or does anyone know how to solve it?

@AbhiHegde3004

AbhiHegde3004 commented Mar 15, 2023

Hi @huyhoang17 @baudm

I converted the PyTorch model to ONNX and then to a TRT engine file.

Inference with ONNX works fine,
but inference with TensorRT in Python always gives the same output.

I used the same pre- and post-processing for the ONNX and TensorRT scripts.

Version: TensorRT-8.5.3.1_cuda11

@kino0924

Hi @huyhoang17

I am having an issue converting ONNX to TRT with
TensorRT-8.6.0.12.Windows10.x86_64.cuda-11.8

trtexec --onnx=poc_sim.onnx --saveEngine=poc.engine --workspace=4096 --verbose

The conversion works, but when I load the engine in Triton, I get this error:
Error Code 1: Serialization (Serialization assertion magicTagRead == kMAGIC_TAG failed.Magic tag does not match)

@phamkhactu
Author


You should use TensorRT 8.5.3.1.

@kino0924

Now my issue is with Triton.
The model loads fine in Triton, but when I make a request, I get this error:

E0420 08:10:14.327730 1 logging.cc:43] 1: [convBaseRunner.cpp::execute::271] Error Code 1: Cask (Cask convolution execution)

triton : nvcr.io/nvidia/tritonserver:23.03-py3

@kino0924

kino0924 commented Apr 20, 2023

I went through a lot of trial and error because my training server and inference server are different.
Here are simple instructions for anyone who, like me, struggles with converting the model to TensorRT.

  1. Simplify the ONNX model with onnx-simplifier.
  2. Serve the simplified ONNX model with the Triton Docker container.
  3. Open a shell in the container:
    sudo docker exec -it {docker_container} sh
  4. Convert the model with trtexec:
    /usr/src/tensorrt/bin/trtexec --onnx=/models/{your_model}/1/model.onnx --saveEngine=model.trt
  5. Copy the converted model back to the host:
    docker cp {docker_container}:/root/model.trt model.trt

Now you have a TensorRT model that will work on the machine that runs the Triton server.
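Step 1 can also be done from Python with the `onnxsim` package's `simplify` function, which returns the simplified model together with a flag saying whether it passed onnx-simplifier's own check. A minimal sketch, assuming `onnx` and `onnxsim` are installed; file names are placeholders and `simplify_onnx` is a hypothetical helper:

```python
def simplify_onnx(src_path: str, dst_path: str) -> None:
    """Simplify an ONNX model before handing it to trtexec."""
    import onnx  # imported lazily so this module loads without onnx installed
    from onnxsim import simplify

    model = onnx.load(src_path)
    model_simp, ok = simplify(model)
    if not ok:
        raise RuntimeError("onnx-simplifier could not validate the simplified model")
    onnx.save(model_simp, dst_path)
```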

@jturner116

@kino0924 I also have different training and inference servers, so thanks for the tip about doing the TensorRT conversion inside the Triton container; that is a good idea. Have you noticed that TensorRT inference is much faster than ONNX inference? I haven't benchmarked it; I thought they would be close.

@kino0924

@jturner116 I did not get as dramatic an improvement as @huyhoang17, but it was definitely worth it.
I have custom img size and charset.
With onnx I was getting 250 infer/sec and now I am getting 400 infer/sec with tensorrt model
I measured with perf_analyzer

@phamkhactu
Author

phamkhactu commented Apr 21, 2023


@kino0924

I've converted the model to TensorRT (8.5.3.3.1) successfully, and inference takes only 0.001 s (excluding model loading, just inferring the image).

I would be glad if you could explain the difference between Triton server and the TensorRT 8.5.3.1 Docker image. I don't use Triton server; I only use a Docker image with TensorRT 8.5.3.1 installed.
Thank you.

@kino0924

kino0924 commented Apr 27, 2023


Triton server is designed for inference serving.
It supports many features such as multi-GPU, multi-model serving, batching, etc.
You will find many benefits in using Triton server.

@RickyGunawan09

@phamkhactu Shoot, that performance degradation at the end is exactly what I noticed too. I thought the conversion to ONNX was supposed to be a bit more stable, I wanted to blame it on my PyTorch version haha. I'm on Torch 1.10.2, were you on 1.12?

@jturner116 @phamkhactu How do you handle the performance degradation with ONNX or TensorRT? I have also encountered this issue, and no one talks about it.

@jturner116

@RickyGunawan09 I found a somewhat hacky solution mentioned in #66. If I give an example input with my maximum character length (25, I think) in the ONNX export instead of a random tensor, I don't notice the performance degradation. If I were smarter I might be able to figure out why that works, but it may have something to do with the EOS token.

@phamkhactu
Author


I've just tested again: the performance degradation with ONNX appears when the input image is bad. Have you ever tested with bad images? @jturner116

@jturner116

jturner116 commented May 24, 2023

@phamkhactu right, often people use random tensors in ONNX exports, but I used this image
[Image: onnx_test (not preserved)]
With random tensors/bad images, the onnx model would sometimes predict correctly for words shorter than my bad image, but using words like the above seems much more stable

EDIT: Sorry, just now realized you probably meant bad image input to the ONNX model. I will test this again too, thanks for the heads up

@jturner116

@phamkhactu I fed random tensors to both the ONNX model and the original model, and the outputs pass

np.testing.assert_allclose(to_numpy(torch_pred), to_numpy(onnx_pred), rtol=1e-03, atol=1e-03) 

Very acceptable for my case
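A slightly more informative variant of that check, sketched in NumPy: besides the tolerance test, it reports whether the greedy token choices agree, and it fails cleanly when the shapes differ (for example when the exported autoregressive graph has a different sequence length than the PyTorch run). `compare_outputs` is a hypothetical helper; the default tolerances are the ones used above.

```python
import numpy as np


def compare_outputs(torch_logits, onnx_logits, rtol=1e-3, atol=1e-3):
    """Return (numerically_close, same_greedy_tokens) for two (batch, steps, classes) arrays."""
    torch_logits = np.asarray(torch_logits)
    onnx_logits = np.asarray(onnx_logits)
    if torch_logits.shape != onnx_logits.shape:
        return False, False  # e.g. a different sequence length baked in at export time
    close = np.allclose(torch_logits, onnx_logits, rtol=rtol, atol=atol)
    same_tokens = (torch_logits.argmax(-1) == onnx_logits.argmax(-1)).all()
    return bool(close), bool(same_tokens)
```

Small numeric differences that do not change the argmax are harmless for greedy decoding, which is why the second flag is often the more useful one.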

@Gavinic

Gavinic commented Jun 6, 2023

@baudm @phamkhactu Hi, I converted my parseq model successfully with decode_ar=False, refine_iters=2. The converted ONNX model doesn't give stable predictions: sometimes redundant repeated characters are generated. For example:
[Screenshot of the example (not preserved)]
This is just like the question in #12 (comment).
Are there any solutions? Thanks!

@phamkhactu
Author

@Gavinic Check again whether you are using refine_iters=2, which can change the length of the output (the original model output is 26x95).
Try setting refine_iters=0.

@keivanmoazami

I converted the model to ONNX successfully, but the TRT engine results are very bad. The full description is at this link: NVIDIA/TensorRT#3136

@kino0924

kino0924 commented Sep 30, 2023

Has anyone had problems converting ONNX to TensorRT 8.6.1?
When I convert the ONNX model to TRT with 8.6.1, I get blank inference results.

@keivanmoazami
I think I am having the same issue as you.
How did you add Cast layers to the 11 encoder blocks?
I am only doing an ONNX to TensorRT FP32 conversion, and it's failing badly.

@keivanmoazami

@kino0924 Use onnx-modifier and add Cast layers as explained in the image.
https://github.com/ZhangGe6/onnx-modifier

@xlg-go

xlg-go commented Jan 5, 2024

When converting ABINet to ONNX, onnx.checker.check_model(onnx_model, full_check=True) fails with:

onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:CumSum, node name: /vision/CumSum): x typestr: T, has unsupported type: tensor(bool)

[Screenshot (not preserved)]

@z3lz

z3lz commented Jan 9, 2024

I have already converted the pretrained parseq_tiny model to ONNX, but I can't work out the post-processing. Is there any Python code for ONNX inference? @baudm
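For the ONNX inference part, a minimal ONNX Runtime sketch: create an `InferenceSession`, feed a preprocessed `(N, 3, 32, 128)` float32 batch to the model's first input, and take the first output as the logits, which can then go through a greedy decoder such as the `TokenDecoder` posted below. The CPU provider, the helper names `make_session` and `run_logits`, and the assumption that the first output holds the logits are illustrative; `run_logits` only relies on the session's `get_inputs()` and `run()` methods.

```python
import numpy as np


def make_session(onnx_path: str):
    """Create an ONNX Runtime session on CPU (swap providers for GPU builds)."""
    import onnxruntime as ort  # lazy import

    return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])


def run_logits(session, batch: np.ndarray) -> np.ndarray:
    """Feed a (N, 3, H, W) batch to the first model input; return the first output."""
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: batch.astype(np.float32)})
    return outputs[0]  # assumed to be the logits, e.g. (N, steps, 95)
```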

@suhas004

suhas004 commented Feb 8, 2024

@z3lz You can use this script for post-processing after converting to ONNX:

from typing import List, Tuple

import torch
from torch import Tensor


class TokenDecoder:
    def __init__(self):
        self.specials_first = ('[E]',)
        self.specials_last = ('[B]', '[P]')
        self.charset = (
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
            'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
            'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '!', '"', '#', '$', '%', '&', "'", '(', ')',
            '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' '
        )
        # Index 0 is [E]; the charset follows; [B] and [P] come last
        self.itos = self.specials_first + self.charset + self.specials_last
        self.stoi = {s: i for i, s in enumerate(self.itos)}

    def ids2tok(self, token_ids: List[int], join: bool = True) -> str:
        tokens = [self.itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    def filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        ids = ids.tolist()
        eos_id = self.stoi['[E]']
        try:
            eos_idx = ids.index(eos_id)
        except ValueError:
            eos_idx = len(ids)  # No EOS found: nothing to truncate.
        ids = ids[:eos_idx]
        probs = probs[:eos_idx + 1]  # Keep the EOS probability as well.
        return probs, ids

    def decode(self, token_dists: Tensor, raw: bool = False):
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # greedy selection
            if not raw:
                probs, ids = self.filter(probs, ids)
            tokens = self.ids2tok(ids)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs


# ort_outs is the list returned by onnxruntime's InferenceSession.run(None, {...})
logits = torch.from_numpy(ort_outs[0])
outputs = logits.softmax(-1)

token_decoder = TokenDecoder()
pred, conf_scores = token_decoder.decode(outputs)
