ScatterElements with Reduction (opset 16) Not Fully Supported #953

Open
anthony-correia opened this issue Jan 10, 2024 · 0 comments

anthony-correia commented Jan 10, 2024

Short Description

Converting an ONNX model to TensorRT with trtexec fails when the model contains a ScatterElements operation with a reduction such as "sum" (opset 16) and the number of indices in that operation exceeds the number of outputs.

Successful conversion requires n_indices <= n_outputs.

Long Description

Consider the following PyTorch model snippet:

import torch
import torch_scatter

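n_indices: int = ...
dim_size: int = ...
n_outputs: int = ...

# Assumption: use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy inputs: one feature row per index, each index pointing at an output row.
e_dummy = torch.randn(size=(n_indices, dim_size), device=device)
index_dummy = torch.randint(high=n_outputs, size=(n_indices,), device=device)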

class ScatterModule(torch.nn.Module):
    def forward(self, e: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        return torch_scatter.scatter(
            src=e,
            # broadcasting (should be done automatically anyway)
            index=index.unsqueeze(-1).expand(-1, e.shape[1]),
            dim=0,
            reduce="sum",
        )

Converting the corresponding ONNX model with trtexec triggers an assertion error:

Assertion failed: indicesDims.d[i] <= dataDims.d[i] && "Indices dimensions must be less than data dimensions!"

This error likely originates from this line of the ONNX-TensorRT code.
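
To make the failing condition concrete, here is a minimal sketch of the per-dimension check quoted in the assertion message, written in plain Python. It is not the ONNX-TensorRT source: the function name `indices_fit_data` is hypothetical, and the shapes assume a data tensor of shape (n_outputs, dim_size) and an indices tensor of shape (n_indices, dim_size).

```python
from typing import Sequence


def indices_fit_data(indices_dims: Sequence[int], data_dims: Sequence[int]) -> bool:
    """Mimic the asserted condition: indicesDims.d[i] <= dataDims.d[i] for every i."""
    if len(indices_dims) != len(data_dims):
        raise ValueError("indices and data must have the same rank")
    return all(i <= d for i, d in zip(indices_dims, data_dims))


# Assumed shapes for the two cases reported in this issue.
failing = indices_fit_data(indices_dims=(1000, 3), data_dims=(100, 3))
passing = indices_fit_data(indices_dims=(100, 3), data_dims=(100, 3))
print(failing, passing)  # False True
```

Under these assumed shapes, the check fails on the first dimension as soon as n_indices is larger than n_outputs, which matches the behaviour described above.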

In the scenarios I've encountered within Graph Neural Networks, the number of indices (n_indices, corresponding to the edges in the graph) is significantly larger than the number of outputs (n_outputs, corresponding to the nodes in the graph).
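
As an illustration of why this case arises, the sketch below is a plain-Python reference for a scatter with a "sum" reduction along dimension 0, aggregating edge features into node features. It is only meant to show the semantics, not how torch_scatter or TensorRT implement the operation. The function name `scatter_sum_axis0` and the toy values are hypothetical, and the output size is passed explicitly as an assumption.

```python
from typing import List


def scatter_sum_axis0(
    updates: List[List[float]], index: List[int], n_outputs: int
) -> List[List[float]]:
    """Sum each row of `updates` into the output row selected by `index`."""
    dim_size = len(updates[0]) if updates else 0
    out = [[0.0] * dim_size for _ in range(n_outputs)]
    for row, target in zip(updates, index):
        for col, value in enumerate(row):
            out[target][col] += value
    return out


# Four edges aggregated into two nodes: n_indices (4) > n_outputs (2).
edge_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
edge_to_node = [0, 1, 0, 1]
node_features = scatter_sum_axis0(edge_features, edge_to_node, n_outputs=2)
print(node_features)  # [[6.0, 8.0], [10.0, 12.0]]
```

The result is well defined even though there are more indices than output rows, since several edges simply accumulate into the same node.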

Environment

TensorRT Version: 8.6.1.6-1+cuda11.8
GPU Type: NVIDIA RTX A2000 (laptop)
Nvidia Driver Version: 520.61.05
CUDA Version: 11.8.0-1
CUDNN Version: 8.7.0.84-1+cuda11.8
Operating System + Version: Ubuntu 22.04.1 LTS

Relevant Files

I've created a repository to reproduce the issue: anthony-correia/scatter_onnx2tensorrt.

The ONNX models are stored with the naming convention onnx/{n_indices}_{dim_size}_{n_outputs}_{seed}.onnx.

To replicate the issue, execute the following commands:

# This command fails when `n_outputs = 100` and `n_indices = 1000`.
trtexec --onnx="onnx/1000_3_100_0.onnx"

# This command succeeds when `n_outputs` equals `n_indices` (both are 100).
trtexec --onnx="onnx/100_3_100_0.onnx"
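
Given the naming convention above, a small helper can recover the parameters from a model path and apply the rule observed in this issue (n_indices <= n_outputs). This is a sketch only: `ModelSpec`, `parse_model_path` and `expected_to_convert` are hypothetical names, and the prediction reflects the reported behaviour rather than a documented TensorRT guarantee.

```python
import re
from typing import NamedTuple


class ModelSpec(NamedTuple):
    n_indices: int
    dim_size: int
    n_outputs: int
    seed: int


# Matches onnx/{n_indices}_{dim_size}_{n_outputs}_{seed}.onnx
_PATTERN = re.compile(r"^onnx/(\d+)_(\d+)_(\d+)_(\d+)\.onnx$")


def parse_model_path(path: str) -> ModelSpec:
    """Extract the four integer parameters encoded in the file name."""
    match = _PATTERN.match(path)
    if match is None:
        raise ValueError(f"path does not follow the naming convention: {path!r}")
    return ModelSpec(*(int(group) for group in match.groups()))


def expected_to_convert(spec: ModelSpec) -> bool:
    """Apply the rule observed in this issue: n_indices <= n_outputs."""
    return spec.n_indices <= spec.n_outputs


for path in ("onnx/1000_3_100_0.onnx", "onnx/100_3_100_0.onnx"):
    spec = parse_model_path(path)
    print(path, spec, expected_to_convert(spec))
```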