Description
I am converting an ONNX model to a TensorRT model, and one of its nodes is a Gather. The conversion fails with the error: Assertion failed: axis >= 0 && axis < nbDims.
The Gather importer is in builtin_op_importers.cpp:
#if NV_TENSORRT_MAJOR >= 4
DEFINE_BUILTIN_OP_IMPORTER(Gather) {
  nvinfer1::ITensor& data = convertToTensor(inputs.at(0), ctx);
  nvinfer1::ITensor& indices = convertToTensor(inputs.at(1), ctx);
  OnnxAttrs attrs(node);
  int axis = attrs.get("axis", 0);
  int nbDims = inputs.at(0).shape().nbDims;
  TRT_CHECK(convert_axis(axis, nbDims));
  RETURN_FIRST_OUTPUT(ctx->network()->addGather(data, indices, axis));
}
#endif // NV_TENSORRT_MAJOR >= 4
The importer calls convert_axis(), which is defined in onnx2trt_utils.hpp:
// Convert an ONNX axis into a TRT axis
inline Status convert_axis(int& axis, int nbDims)
{
  // Support negative indexing
  if (axis < 0)
  {
    axis += nbDims;
  }
  // If axis was positive, subtract 1 to strip batch dimension
  else
  {
    axis = axis - 1;
  }
  ASSERT(axis >= 0 && axis < nbDims, ErrorCode::kUNSUPPORTED_NODE);
  return Status::success();
}
In my case the inputs are axis = 0 and nbDims = 1. Since axis is not negative, the else branch runs and sets axis = 0 - 1 = -1, so the check axis >= 0 fails and the assertion fires.
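To make the arithmetic concrete, here is a minimal standalone sketch of the same mapping. It is not the onnx-tensorrt code itself: convert_axis_sketch and report are hypothetical names I made up for illustration, and the function returns false where the original ASSERT would fire instead of returning a Status.

```cpp
#include <cstdio>

// Hypothetical standalone copy of the axis mapping quoted above.
// Returns false in the cases where the original ASSERT would fail.
bool convert_axis_sketch(int& axis, int nbDims)
{
  if (axis < 0)
  {
    axis += nbDims;   // negative ONNX axis: count from the end
  }
  else
  {
    axis = axis - 1;  // non-negative ONNX axis: strip the batch dimension
  }
  return axis >= 0 && axis < nbDims;
}

// Hypothetical helper: prints the mapped axis and whether the check passes.
void report(int onnxAxis, int nbDims)
{
  int axis = onnxAxis;
  bool ok = convert_axis_sketch(axis, nbDims);
  std::printf("onnx axis %d, nbDims %d -> trt axis %d (%s)\n",
              onnxAxis, nbDims, axis, ok ? "ok" : "assertion would fail");
}
```

Calling report(0, 1) reproduces my case: the mapped axis is -1 and the check fails. Calling report(1, 3) maps to axis 0 and passes. Going by the comment in the original code, a non-negative axis is shifted down by one to strip the batch dimension, so an ONNX axis of 0 always ends up at -1 regardless of nbDims.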
My question is: how should a Gather with axis = 0 be handled?
I am also unclear about how axes are mapped between ONNX and TensorRT.
Thanks.