
tensorrt 7.0 In function importInstanceNormalization: [8] Assertion failed: !isDynamic(tensor_ptr->getDimensions()) && "InstanceNormalization does not support dynamic inputs!" #374

Closed
lucasjinreal opened this issue Jan 16, 2020 · 18 comments

@lucasjinreal
Contributor

 In function importInstanceNormalization:
[8] Assertion failed: !isDynamic(tensor_ptr->getDimensions()) && "InstanceNormalization does not support dynamic inputs!"

Using TensorRT 7.0 with a PyTorch 1.4 export at opset 11, I get the error above.

The weird thing is that the same model converts fine when exported with opset 10.
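For reference, a minimal sketch of the opset-10 export that worked here (the toy model and file name below are placeholders, not the actual model from this issue):

import torch
from torch import nn

# Hypothetical stand-in model; the relevant part is only the opset flag.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(2, 8)).eval()
dummy = torch.randn(1, 3, 224, 224)

# Exporting with opset_version=10 avoided the error for this model,
# though later comments show it is not a universal fix.
torch.onnx.export(model, dummy, "model_opset10.onnx", opset_version=10,
                  input_names=["input"], output_names=["output"])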

@arjunbhargava

arjunbhargava commented Jan 22, 2020

I'm having a similar issue, but also with opset 10. GroupNorm is exported as an InstanceNormalization op (which makes sense as a workaround), but TRT 7 seems to introduce this new error. Oddly, a single normalization layer causes no issue, but once we add another one the engine fails to build. I've attached a minimal reproducible case here:

import numpy as np
import onnx
import onnxruntime as rt
import torch
from torch import nn

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)


class TestModel(nn.Module):
    """Minimal model for reproducing error"""
    def __init__(self):
        super().__init__()
        # Block 1
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.gn = nn.GroupNorm(8, 32, affine=True)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)

        # An additional GN breaks.
        self.gn2 = nn.GroupNorm(2, 32, affine=True)

    def forward(self, x):
        out = self.conv1(x)
        out = self.gn(out)
        out = self.conv2(out)

        out = self.gn2(out) # If we comment out this line, there are no issues
        return out


def export_model(image_width=224, image_height=224):
    """Export test-case ONNX Graph"""

    model_file = "test.onnx"
    unrolled_model = TestModel()
    unrolled_model.eval()

    img = torch.randn(1, 3, image_height, image_width)

    # Export to ONNX (the model is already in eval mode)
    torch.onnx.export(
        unrolled_model,
        img,
        model_file,
        verbose=False,
        do_constant_folding=False,
        opset_version=10,
        input_names=['input'],
        output_names=['output']
    )

    # Try to load model and check validity
    onnx_model = onnx.load(model_file)
    onnx.checker.check_model(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))

    # Run unrolled model
    with torch.no_grad():
        outputs = unrolled_model(img)
        outputs = outputs.numpy().squeeze()

    # Run ONNX model
    sess = rt.InferenceSession(model_file)
    input_name = sess.get_inputs()[0].name
    pred_onnx = sess.run(None, {input_name: img.numpy()})
    outputs_onnx = pred_onnx[0].squeeze()

    # Now compare output tensors
    assert np.allclose(outputs, outputs_onnx, atol=1e-03)
    return onnx_model


def GiB(val):
    return val * 1 << 30


def build_engine_onnx(model_file, workspace_GiB=2, max_batch_size=1):
    """Build engine from ONNX file"""
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    ) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:

        builder.max_workspace_size = GiB(workspace_GiB)
        builder.max_batch_size = max_batch_size
        builder.fp16_mode = False

        with open(model_file, 'rb') as model:
            pars_res = parser.parse(model.read())
        if pars_res:
            print("OK ONNX")
            print('Number of errors: {}'.format(parser.num_errors))
        else:
            print('Number of errors: {}'.format(parser.num_errors))
            for err_idx in range(parser.num_errors):
                error = parser.get_error(err_idx)
                print('Description of the error: {}'.format(error.desc()))
                print('Line where the error occurred: {}'.format(error.line()))
                print('Error code: {}'.format(error.code()))
                print("Model was not parsed successfully")
            del parser
            return None

        engine = builder.build_cuda_engine(network)
        print("ENGINE = ", engine)
        print("NETWORK = ", network)
        return engine


def build_engine(onnx_graph):
    # Try to load model and check validity.
    onnx.checker.check_model(onnx_graph)
    sess = rt.InferenceSession("test.onnx")

    print('Inputs:')
    [print(i) for i in sess.get_inputs()]
    print('Outputs:')
    [print(o) for o in sess.get_outputs()]

    # Build TRT engine.
    engine = build_engine_onnx("test.onnx")
    assert engine is not None
    print("DONE")


if __name__ == "__main__":
    build_engine(export_model())

Environment details:

 TORCHVISION_VERSION=0.5.0
 CUDNN_VERSION=7.6.3.30-1+cuda10.0
 NCCL_VERSION=2.5.6-1+cuda10.0
 TRT_VERSION=7.0.0.11

@lucasjinreal
Contributor Author

Are there any updates on this?

@TengFeiHan0

@jinfagang did you solve this problem? I am trying to speed up FCOS with TensorRT and I hit the same kind of error.

@springyoung

@TengFeiHan0 Same here. I used onnx2trt to parse my fcos.onnx model and got the same error!

@dechunwang

dechunwang commented Aug 12, 2020

@jinfagang Any updates? I got the same issue with TensorRT 7.1, PyTorch 1.6.
@rajeevsrao Will this be fixed in the near future?

trtexec --onnx=model.onnx --verbose

[08/12/2020-10:24:43] [V] [TRT] ImporterContext.hpp:116: Registering tensor: 393 for ONNX tensor: 393
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:179: Reshape_204 [Reshape] outputs: [393 -> (-1, 32, -1)],
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:103: Parsing node: Constant_205 [Constant]
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:125: Constant_205 [Constant] inputs:
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:179: Constant_205 [Constant] outputs: [394 -> (32)],
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:103: Parsing node: Constant_206 [Constant]
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:125: Constant_206 [Constant] inputs:
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:179: Constant_206 [Constant] outputs: [395 -> (32)],
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:103: Parsing node: InstanceNormalization_207 [InstanceNormalization]
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:119: Searching for input: 393
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:119: Searching for input: 394
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:119: Searching for input: 395
[08/12/2020-10:24:43] [V] [TRT] ModelImporter.cpp:125: InstanceNormalization_207 [InstanceNormalization] inputs: [393 -> (-1, 32, -1)], [394 -> (32)], [395 -> (32)],
ERROR: builtin_op_importers.cpp:1595 In function importInstanceNormalization:
[8] Assertion failed: !isDynamic(tensorPtr->getDimensions()) && "InstanceNormalization does not support dynamic inputs!

@rajeevsrao
Collaborator

@dechunwang @jinfagang yes, the fix will be in our next ONNX parser monthly update. Thanks.

@VincentGu11

In TRT 7.1.3.4, converting FCOS with FPN, I also get the same problem:
ERROR: builtin_op_importers.cpp:1595 In function importInstanceNormalization:
[8] Assertion failed: !isDynamic(tensorPtr->getDimensions()) && "InstanceNormalization does not support dynamic inputs!"

@olferuk

olferuk commented Nov 20, 2020

Any updates?

@kevinch-nv
Collaborator

The InstanceNormalizationPlugin has been updated to support dynamic shapes in the latest TRT release (7.2). For those still following this thread, please upgrade to TRT 7.2 and try importing your model again.

I will be closing this issue since this has been fixed; if anyone is still having trouble with TRT 7.2, feel free to open a new issue.
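A quick sanity check before retrying (a minimal sketch; it only inspects the version string of whichever TensorRT the Python environment actually imports):

import tensorrt as trt

# The dynamic-shape InstanceNormalization fix landed in 7.2, so anything
# older will still hit the assertion.
print(trt.__version__)
assert tuple(int(v) for v in trt.__version__.split(".")[:2]) >= (7, 2)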

@kevinch-nv added the labels enhancement (New feature or request) and triaged (Issue has been triaged by maintainers) on Dec 15, 2020
@VincentGu11

I used TRT 7.2.2.3 to convert a model with a dynamic-input InstanceNormalization and I still have the same problem; my CUDA version is 10.2.

@kevinch-nv
Collaborator

@VincentGu11 can you open a new issue about this and attach the model you are having trouble with?

@ThrillerWYY

@VincentGu11 can you open a new issue about this and attach the model you are having trouble with?

Hello, have you solved this? I met the same problem: UNSUPPORTED_NODE: Assertion failed: !isDynamic(tensor_ptr->getDimensions()) && "InstanceNormalization does not support dynamic inputs!"

@ThrillerWYY

I used TRT 7.2.2.3 to convert a model with a dynamic-input InstanceNormalization and I still have the same problem; my CUDA version is 10.2.

Hi, have you solved this problem?

@ThrillerWYY

CUDNN_VERSION=7.6.3.30-1+cuda10.0
NCCL_VERSION=2.5.6-1+cuda10.0
TRT_VERSION=7.0.0.11

My environment is the same as yours. How did you solve this problem?

@kevinch-nv
Collaborator

As mentioned in my previous comment, this failure is expected with TRT 7.0 and has been fixed in TRT 7.2. Can you upgrade your TRT version?

CUDNN_VERSION=7.6.3.30-1+cuda10.0
NCCL_VERSION=2.5.6-1+cuda10.0
TRT_VERSION=7.0.0.11

My environment is the same as yours. How did you solve this problem?

@thishome

thishome commented May 8, 2021

@kevinch-nv Hi, I got the same problem. Please tell me how I can update the TRT version to 7.2.0 or newer; it is currently 7.1.3, installed by JetPack. Is there a way to upgrade only TensorRT on the Jetson Xavier device?

@kevinch-nv
Collaborator

kevinch-nv commented May 10, 2021

The required changes are plugin-only; I believe the core TRT version does not matter.

You can follow the JetPack cross-compilation instructions at https://github.com/NVIDIA/TensorRT/ to build the updated plugins, then copy the newly built binaries onto the Jetson device and have your application link against those. A rough sketch follows.
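Roughly (an unverified sketch; the branch name, install paths, and build flags below are assumptions, so check the repo README for the exact cross-compile steps for your JetPack version):

git clone -b release/7.2 https://github.com/NVIDIA/TensorRT.git
cd TensorRT && git submodule update --init --recursive
mkdir -p build && cd build
cmake .. -DTRT_LIB_DIR=/usr/lib/aarch64-linux-gnu -DTRT_OUT_DIR=$(pwd)/out
make -j$(nproc) nvinfer_plugin
# Back up the JetPack-installed libnvinfer_plugin.so*, then replace it
# with the freshly built one from the out/ directory.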

@liuxufenfeiya

liuxufenfeiya commented May 19, 2021

TensorRT 7.2.2, same error.
(attached screenshot: onnx2trt_error)
