
[8] Assertion failed: ctx->network()->hasExplicitPrecision() && "TensorRT only supports multi-input conv for explicit precision QAT networks!" #645

Closed · lucasjinreal opened this issue Feb 19, 2021 · 6 comments
Labels: duplicate (this issue or pull request already exists), triaged (issue has been triaged by maintainers)

Comments

@lucasjinreal (Contributor)

Hi, I'm trying to convert an ONNX model with an ordinary float32 data type, not a QAT model, but it gives me this error message:

[8] Assertion failed: ctx->network()->hasExplicitPrecision() && "TensorRT only supports multi-input conv for explicit precision QAT networks!"

I can reproduce the error with this minimal example:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MG(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, b):
        # conv with a runtime tensor as the weight (a multi-input Conv in ONNX)
        preds = F.conv2d(x, b, stride=1)
        preds = preds.to(torch.float)
        preds = preds.sigmoid().float()
        # compare against a float tensor to avoid an implicit cast to double
        seg_masks = preds > torch.tensor(0.03, dtype=torch.float)
        return seg_masks


torch_model = MG()
x = torch.randn([1, 4, 24, 24])
b = torch.randn([8, 4, 3, 3])
torch_out = torch_model(x, b)

# Export the model
torch.onnx.export(torch_model,          # model being run
                  (x, b),
                  "a.onnx",
                  export_params=True,   # store the trained parameter weights inside the model file
                  opset_version=11,     # the ONNX version to export the model to
                  do_constant_folding=True,
                  verbose=True)
print('Done!')

If you export the ONNX model with PyTorch 1.7 and try to convert it to a TRT engine, it shows this error:

[8] Assertion failed: ctx->network()->hasExplicitPrecision() && "TensorRT only supports multi-input conv for explicit precision QAT networks!"

You might ask why I use torch.tensor(0.03, dtype=torch.float) in the > op: without it, PyTorch casts the float to double and introduces a double data type into the ONNX graph.

That, in turn, makes onnx2trt raise another error: unsupported datatype 11.
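
For reference, ONNX encodes element types as integer enums, and 11 is DOUBLE; a quick check (assuming the onnx Python package is installed):

import onnx

# The "unsupported datatype 11" in the onnx2trt error is ONNX's DOUBLE type.
print(onnx.TensorProto.FLOAT)   # 1
print(onnx.TensorProto.DOUBLE)  # 11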

So how should we solve this awkward situation?

@lucasjinreal (Contributor, Author)

@mk-nvidia To be more specific, the problem is this:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MG(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, b):
        # multi-input Conv: the weight b is a runtime tensor, not a constant
        preds = F.conv2d(x, b, stride=1)
        return preds


torch_model = MG()
x = torch.randn([1, 4, 24, 24])
b = torch.randn([8, 4, 3, 3])
torch_out = torch_model(x, b)

# Export the model
torch.onnx.export(torch_model,          # model being run
                  (x, b),
                  "a.onnx",
                  export_params=True,   # store the trained parameter weights inside the model file
                  opset_version=11,     # the ONNX version to export the model to
                  do_constant_folding=True,
                  verbose=True)
print('Done!')

The root cause is that this op cannot be converted: preds = F.conv2d(x, b, stride=1), because the weight b is a runtime tensor rather than a constant.

We rely on this pattern in many modern models, such as SOLOv2 and dynamic convolutions.
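
One possible workaround, a minimal sketch not taken from this thread: express the dynamic-weight convolution as unfold + matmul so that the exported graph contains no multi-input Conv. Whether unfold itself exports cleanly depends on your PyTorch and opset versions, so treat this as an idea to test, not a guaranteed fix:

import torch
import torch.nn.functional as F

def conv2d_as_matmul(x, w, stride=1):
    # im2col identity: conv2d(x, w) == w_flat @ unfold(x), reshaped back
    n, c, h, width = x.shape
    oc, _, kh, kw = w.shape
    cols = F.unfold(x, kernel_size=(kh, kw), stride=stride)  # (n, c*kh*kw, L)
    out = w.view(oc, -1) @ cols                              # (n, oc, L)
    oh = (h - kh) // stride + 1
    ow = (width - kw) // stride + 1
    return out.view(n, oc, oh, ow)

x = torch.randn(1, 4, 24, 24)
b = torch.randn(8, 4, 3, 3)
assert torch.allclose(conv2d_as_matmul(x, b), F.conv2d(x, b, stride=1), atol=1e-4)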

@lucasjinreal (Contributor, Author)

@mk-nvidia Please look:

[screenshot of the assertion error]

[8] Assertion failed: ctx->network()->hasExplicitPrecision() && "TensorRT only supports multi-input conv for explicit precision QAT networks!"

@zhenhuaw-me (Member)

@jinfagang TensorRT requires that the second input of the Conv be an initializer in the ONNX model. It seems to me that your model could be cleaned up with ONNX Optimizer and ONNX Simplifier. Could you try those tools and then run TensorRT on the processed ONNX model?
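
A minimal sketch of that workflow, assuming the onnx and onnxsim Python packages and the a.onnx file exported above:

import onnx
from onnxsim import simplify

model = onnx.load("a.onnx")

# Check whether each Conv weight is an initializer (a constant),
# which is what TensorRT expects for non-QAT networks.
inits = {i.name for i in model.graph.initializer}
for node in model.graph.node:
    if node.op_type == "Conv":
        print(node.name, "weight is initializer:", node.input[1] in inits)

simplified, ok = simplify(model)
assert ok, "simplification failed"
onnx.save(simplified, "a_simplified.onnx")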

@kochsebastian commented Jun 7, 2021

@jackwish I get the same error with this simple model:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, weight, grad_output):
        return F.grad.conv2d_weight(input, weight.shape, grad_output)


dw = Net()

input = torch.randn(1, 1, 3, 3, requires_grad=True)
weight = torch.randn(1, 1, 1, 2, requires_grad=True)
output = F.conv2d(input, weight)
grad_output = torch.randn(output.shape)
grad_input = torch.autograd.grad(output, input, grad_output)
out = dw(input, weight, grad_output)
torch.onnx.export(dw, (input, weight, grad_output),
                  'test_conv_derivative.onnx', verbose=True, opset_version=11)

test_conv_derivative.zip
Neither onnx-optimizer nor onnx-simplifier worked for me.

Conv2d derivative source code (note that it passes grad_output, a runtime tensor, as the weight of torch.conv2d, which is exactly the multi-input Conv that TensorRT rejects):

# from torch/nn/grad.py
import torch
from torch.nn.modules.utils import _pair


def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0,
                  dilation=1, groups=1):
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    in_channels = input.shape[1]
    out_channels = grad_output.shape[1]
    min_batch = input.shape[0]

    grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1, 1)
    grad_output = grad_output.contiguous().view(
        grad_output.shape[0] * grad_output.shape[1], 1,
        grad_output.shape[2], grad_output.shape[3])

    input = input.contiguous().view(
        1, input.shape[0] * input.shape[1], input.shape[2], input.shape[3])

    # the multi-input Conv: grad_output is used as the conv weight
    grad_weight = torch.conv2d(input, grad_output, None, dilation, padding,
                               stride, in_channels * min_batch)

    grad_weight = grad_weight.contiguous().view(
        min_batch, grad_weight.shape[1] // min_batch,
        grad_weight.shape[2], grad_weight.shape[3])

    return grad_weight.sum(dim=0).view(
        in_channels // groups, out_channels,
        grad_weight.shape[2], grad_weight.shape[3]).transpose(0, 1).narrow(
            2, 0, weight_size[2]).narrow(3, 0, weight_size[3])
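
The torch.conv2d call above uses grad_output, a tensor derived from a graph input, as its weight, so after export the Conv's second input is an intermediate tensor that no constant folder can turn into an initializer. A small diagnostic sketch (assuming the attached test_conv_derivative.onnx):

import onnx

model = onnx.load("test_conv_derivative.onnx")
inits = {i.name for i in model.graph.initializer}
producers = {out: n for n in model.graph.node for out in n.output}
for node in model.graph.node:
    if node.op_type == "Conv":
        w = node.input[1]
        src = producers[w].op_type if w in producers else "graph input or initializer"
        print("Conv weight '%s': initializer=%s, produced by %s" % (w, w in inits, src))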

Edit: I replaced the functional conv2d with nn.Conv2d, setting the weights during the inference path, but unfortunately it didn't work either.

@joan126 commented Feb 8, 2022

I also hit this issue. Have you solved it? @jinfagang

@kevinch-nv (Collaborator)

Closing as duplicate of #609

@kevinch-nv added the duplicate and triaged labels on Mar 21, 2022