Description
for (int32_t i = 0; i < dataDims.nbDims; ++i)
{
    if (indicesDims.d[i] != -1 && dataDims.d[i] != -1)
    {
        ASSERT(indicesDims.d[i] <= dataDims.d[i] && "Indices dimensions must be less than data dimensions!",
            ErrorCode::kUNSUPPORTED_NODE);
    }
    if (updatesDims.d[i] != -1 && dataDims.d[i] != -1)
    {
        ASSERT(updatesDims.d[i] <= dataDims.d[i] && "Updates dimensions must be less than data dimensions!",
            ErrorCode::kUNSUPPORTED_NODE);
    }
}
In this section, the assertion indicesDims.d[i] <= dataDims.d[i] should be changed to indices.max() <= dataDims.d[i], and the assertion updatesDims.d[i] <= dataDims.d[i] should be changed to updates.max() <= dataDims.d[i].
This is because scatter addresses data through the values stored in indices, not through the shape of the indices tensor. Hence the sizes in dataDims need not be related to the sizes in indicesDims/updatesDims, apart from the requirement that all three tensors have the same rank.
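To make the argument concrete, here is a minimal pure-Python sketch of scatter-add along axis 1 for 2-D inputs. It is not TensorRT or PyTorch code: the function name `scatter_add_dim1` is hypothetical, negative indices are rejected for simplicity, and the shape rule it enforces (index values strictly below the data size along the scatter axis, with no constraint on how many indices there are along that axis) is my reading of the scatter semantics rather than something the parser source confirms.

```python
def scatter_add_dim1(data, indices, updates):
    """Reference scatter-add along axis 1 for 2-D nested lists (illustrative only)."""
    same_shape = len(indices) == len(updates) and all(
        len(i_row) == len(u_row) for i_row, u_row in zip(indices, updates)
    )
    if not same_shape:
        raise ValueError("indices and updates must have the same shape")
    # On the non-scatter axis (rows), every row of indices needs a matching row in data.
    if len(indices) > len(data):
        raise ValueError("indices has more rows than data")

    out = [row[:] for row in data]  # do not modify the caller's data
    for r, (idx_row, upd_row) in enumerate(zip(indices, updates)):
        width = len(out[r])
        for idx, upd in zip(idx_row, upd_row):
            # The index VALUE must address a valid column of data;
            # the NUMBER of indices in the row is unconstrained.
            if not 0 <= idx < width:
                raise IndexError(f"index {idx} out of range for size {width}")
            out[r][idx] += upd
    return out


# Shapes from this report: data (1, 20), indices and updates (1, 4096).
data = [[0] * 20]
indices = [[j % 20 for j in range(4096)]]  # values stay in [0, 19]
updates = [[1] * 4096]
result = scatter_add_dim1(data, indices, updates)
```

Here `indices` is far wider than `data` along the scatter axis, yet every write lands inside `data`, which is the case the shape assertion rejects.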
Environment
TensorRT Version: 8.6.1.6
ONNX-TensorRT Version / Branch: 1.16
GPU Type: RTX 4070
Nvidia Driver Version: 535.146.02
CUDA Version: 12.2
CUDNN Version: 11.8
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.10
TensorFlow + TF2ONNX Version (if applicable):
PyTorch Version (if applicable): 2.2.1
Baremetal or Container (if container which image + tag):
Relevant Files
No relevant files.
Steps To Reproduce
You can reproduce it by exporting the following to an ONNX model and then converting that model into a TensorRT engine:
y = torch.scatter_add(input=y_, dim=1, index=x, src=x_)
where y_ has shape (1, 20), x has shape (1, 4096) with a maximum value <= 19, and x_ has shape (1, 4096) and is filled with 1's.
The export to ONNX succeeds, but building the TensorRT engine fails.
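A possible repro script for the ONNX export step is sketched below. Several details are assumptions not stated in this report: the wrapper module `ScatterAdd`, the helper `export_scatter_add_to_onnx`, the tensor dtypes (float32 data, int64 indices), the random index values, the input/output names, the output file name, and opset 16. PyTorch is imported inside the helper so the shape constants can be inspected without it installed.

```python
import io

DATA_SHAPE = (1, 20)     # shape of y_
INDEX_SHAPE = (1, 4096)  # shape of x and x_


def export_scatter_add_to_onnx(opset_version=16):
    """Export the scatter_add call to ONNX and return the serialized model bytes."""
    import torch  # lazy import: only needed when the export is actually run

    class ScatterAdd(torch.nn.Module):  # hypothetical wrapper module
        def forward(self, y_, x, x_):
            return torch.scatter_add(input=y_, dim=1, index=x, src=x_)

    y_ = torch.zeros(DATA_SHAPE)
    # int64 indices in [0, 19], so every index addresses a valid column of y_
    x = torch.randint(0, DATA_SHAPE[1], INDEX_SHAPE)
    x_ = torch.ones(INDEX_SHAPE)

    buffer = io.BytesIO()
    torch.onnx.export(
        ScatterAdd(),
        (y_, x, x_),
        buffer,
        input_names=["y_", "x", "x_"],
        output_names=["y"],
        opset_version=opset_version,  # assumption: the report does not state the opset
    )
    return buffer.getvalue()


if __name__ == "__main__":
    with open("scatter_add.onnx", "wb") as f:
        f.write(export_scatter_add_to_onnx())
```

The resulting file can then be passed to the TensorRT ONNX parser, for example with `trtexec --onnx=scatter_add.onnx`, which is the step reported to fail.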