2 changes: 1 addition & 1 deletion BUILD.md
@@ -187,7 +187,7 @@ See more information on the TensorRT Execution Provider [here](./docs/execution_
* The path to the CUDA `bin` directory must be added to the PATH environment variable so that `nvcc` is found.
* The path to the cuDNN installation (path to the folder that contains libcudnn.so) must be provided via the cuDNN_PATH environment variable, or the `--cudnn_home` parameter.
* Install [TensorRT](https://developer.nvidia.com/nvidia-tensorrt-download)
* The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5.
* The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5, but validated with a feature set equivalent to TensorRT 5. Some new TensorRT 6 features, such as dynamic shapes, are not available at this time.
* The path to the TensorRT installation must be provided via the `--tensorrt_home` parameter.

#### Build Instructions
2 changes: 2 additions & 0 deletions docs/execution_providers/TensorRT-ExecutionProvider.md
@@ -7,6 +7,8 @@ With the TensorRT execution provider, the ONNX Runtime delivers better inferenci
## Build
For build instructions, please see the [BUILD page](../../BUILD.md#tensorrt).

The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5, but validated with a feature set equivalent to TensorRT 5. Some new TensorRT 6 features, such as dynamic shapes, are not available at this time.

## Using the TensorRT execution provider
### C/C++
The TensorRT execution provider needs to be registered with ONNX Runtime to enable it in the inference session.
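A minimal sketch of that registration, adapted from the pattern used in the unit test added by this pull request (tensorrt_basic_test.cc); the model path, session log id, and use of ORT_ENFORCE are illustrative placeholders rather than a prescribed API:

```c++
// Sketch only: mirrors the registration pattern used in tensorrt_basic_test.cc.
// "model.onnx" and the session log id are placeholder values.
SessionOptions so;
so.session_logid = "TensorrtExecutionProvider";

InferenceSession session_object{so};

// Configure and register the TensorRT execution provider before loading the model.
TensorrtExecutionProviderInfo epi;
epi.device_id = 0;  // GPU device the provider should use
ORT_ENFORCE(session_object.RegisterExecutionProvider(
    onnxruntime::make_unique<::onnxruntime::TensorrtExecutionProvider>(epi)).IsOK());

ORT_ENFORCE(session_object.Load("model.onnx").IsOK());
ORT_ENFORCE(session_object.Initialize().IsOK());
// session_object.Run(run_options, feeds, output_names, &fetches) then executes the
// TensorRT-supported parts of the graph with the TensorRT execution provider.
```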
@@ -267,10 +267,11 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect

SubGraphCollection_t next_nodes_list;
const onnxruntime::GraphViewer graph_viewer(graph_build);
const std::vector<NodeIndex>& subgraph_node_index = graph_viewer.GetNodesInTopologicalOrder();
next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, graph_viewer, early_termination);
for (int i = 0, end = next_nodes_list.size(); i < end; ++i) {
for (int j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) {
next_nodes_list[i].first[j] = group.first[next_nodes_list[i].first[j]];
next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]];
}
Contributor: Can you add a unit test to validate this change?
nodes_list_output.push_back(next_nodes_list[i]);
}
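For context on the change above: the indices returned by the recursive GetSupportedList call refer to positions in the rebuilt sub-graph's topological order, so they are first mapped through subgraph_node_index (the result of GetNodesInTopologicalOrder()) to node indices within the sub-graph, and only then used to look up the original graph's node indices in group.first. Below is a minimal, self-contained sketch of that two-step translation, using plain vectors with hypothetical values rather than ONNX Runtime types:

```c++
// Illustrative only: a plain-vector analogue of the index translation above.
// The names mirror the variables in the diff, but the values are hypothetical.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Node indices of the candidate group within the original graph
  // (the role played by group.first).
  const std::vector<std::size_t> group_first = {10, 11, 12};

  // Topological order of the rebuilt sub-graph: position k holds the node index
  // within the sub-graph (the role of GetNodesInTopologicalOrder()).
  const std::vector<std::size_t> subgraph_node_index = {2, 0, 1};

  // Indices returned by the recursive call: positions in that topological order.
  const std::vector<std::size_t> supported = {0, 2};

  // Old mapping, group_first[supported[j]], would select nodes {10, 12}.
  // New mapping goes through subgraph_node_index first and selects {12, 11}.
  std::vector<std::size_t> remapped;
  for (std::size_t j = 0; j < supported.size(); ++j) {
    remapped.push_back(group_first[subgraph_node_index[supported[j]]]);
  }
  assert(remapped == (std::vector<std::size_t>{12, 11}));
  return 0;
}
```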
115 changes: 115 additions & 0 deletions onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -102,5 +102,120 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) {
ASSERT_TRUE(status.IsOK());
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
}

// Verifies that node indices reported for TensorRT-supported sub-graphs are mapped
// back to node indices of the original graph (the GetSupportedList change above).
TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) {
onnxruntime::Model model("graph_1");
auto& graph = model.MainGraph();
std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;

// FLOAT tensor.
ONNX_NAMESPACE::TypeProto float_tensor;
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);

// BOOL tensor.
ONNX_NAMESPACE::TypeProto bool_tensor;
bool_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL);
bool_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
bool_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
bool_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);

// UINT8 tensor.
ONNX_NAMESPACE::TypeProto uint8_tensor;
uint8_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
uint8_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
uint8_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
uint8_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);

auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &bool_tensor);
inputs.push_back(&input_arg_1);
auto& output_arg_1 = graph.GetOrCreateNodeArg("node_1_out", &uint8_tensor);
outputs.push_back(&output_arg_1);
auto& cast_node = graph.AddNode("cast1", "Cast", "node 1.", inputs, outputs);
AttributeProto attr_proto;
attr_proto.set_name("to");
attr_proto.set_type(AttributeProto_AttributeType_INT);
attr_proto.set_i(2);  // TensorProto_DataType_UINT8
cast_node.AddAttribute("to", attr_proto);

inputs.clear();
inputs.push_back(&output_arg_1);
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &bool_tensor);
outputs.clear();
outputs.push_back(&output_arg_2);
auto& cast_node_2 = graph.AddNode("cast2", "Cast", "node 2.", inputs, outputs);
AttributeProto attr_proto_2;
attr_proto_2.set_name("to");
attr_proto_2.set_type(AttributeProto_AttributeType_INT);
attr_proto_2.set_i(9);  // TensorProto_DataType_BOOL
cast_node_2.AddAttribute("to", attr_proto_2);

auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
inputs.clear();
inputs.push_back(&input_arg_2);
inputs.push_back(&input_arg_3);
auto& output_arg_3 = graph.GetOrCreateNodeArg("N", &float_tensor);
outputs.clear();
outputs.push_back(&output_arg_3);
graph.AddNode("sub", "Sub", "node 3.", inputs, outputs);

auto status = graph.Resolve();
ASSERT_TRUE(status.IsOK());
std::string model_file_name = "trt_execution_provider_NodeIndexMappingTest.onnx";
status = onnxruntime::Model::Save(model, model_file_name);

std::vector<int64_t> dims_mul_x = {1, 3, 2};
std::vector<bool> values_mul_x = {true, false, true, false, true, false};
std::vector<int64_t> dims_mul_y = {1, 3, 2};
std::vector<float> values_mul_y = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
OrtValue ml_value_x;
CreateMLValue<bool>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_x, values_mul_x, &ml_value_x);
OrtValue ml_value_y;
CreateMLValue<float>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_y);
OrtValue ml_value_z;
CreateMLValue<float>(TestTensorrtExecutionProvider()->GetAllocator(0, OrtMemTypeCPU), dims_mul_y, values_mul_y, &ml_value_z);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
feeds.insert(std::make_pair("Y", ml_value_y));
feeds.insert(std::make_pair("Z", ml_value_z));

// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("M");
output_names.push_back("N");
std::vector<OrtValue> fetches;

// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_mul_m = {1, 3, 2};
std::vector<bool> expected_values_mul_m = {true, false, true, false, true, false};
std::vector<int64_t> expected_dims_mul_n = {1, 3, 2};
std::vector<float> expected_values_mul_n = {0, 0, 0, 0, 0, 0};

SessionOptions so;
so.session_logid = "TensorrtExecutionProviderTest.NodeIndexMappingTest";
RunOptions run_options;
run_options.run_tag = so.session_logid;

InferenceSession session_object{so};

TensorrtExecutionProviderInfo epi;
epi.device_id = 0;
EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::TensorrtExecutionProvider>(epi)).IsOK());

status = session_object.Load(model_file_name);
ASSERT_TRUE(status.IsOK());
status = session_object.Initialize();
ASSERT_TRUE(status.IsOK());

// Now run
status = session_object.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(status.IsOK());
// Only the last output ("N", produced by the Sub node) is verified here.
std::vector<OrtValue> fetch{fetches.back()};
VerifyOutputs(fetch, expected_dims_mul_n, expected_values_mul_n);
}
} // namespace test
} // namespace onnxruntime