Skip to content

Commit

Permalink
Enable support for multi-level nested control flow op models in the TRT EP (
Browse files Browse the repository at this point in the history
#12147)

* Make multiple-level nested control flow op model work

* find correct input index

* find correct input index (cont.)

* enable nested layer unit tests for TRT EP

* add comment

* add Scan op to the current workaround support for control flow ops
  • Loading branch information
chilo-ms committed Aug 2, 2022
1 parent de3a91d commit b39257a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
29 changes: 27 additions & 2 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,14 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
if (node->GetOutputEdgesCount() > node->OutputDefs().size()) {
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
const auto& node_idx = it->GetNode().Index();
const auto& output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()];
const onnxruntime::NodeArg* output;
// The dst_arg_index from GetDstArgIndex() could be the index for explicit/implicit input defs of the node.
// We need to get the correct input index accordingly. (See Graph::BuildConnections() in graph.cc for more details)
if (it->GetDstArgIndex() < static_cast<int>(it->GetNode().InputDefs().size())) {
output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()];
} else {
output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast<int>(it->GetNode().InputDefs().size())];
}
if (node_set.find(node_idx) != node_set.end()) {
const auto& iter = fused_inputs.find(output);
if (iter != fused_inputs.end()) {
Expand Down Expand Up @@ -897,6 +904,10 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t&
input_to_nodes_map[input->Name()].insert(node_name);
}

for (const auto& input : node->ImplicitInputDefs()) {
input_to_nodes_map[input->Name()].insert(node_name);
}

for (const auto& output : node->OutputDefs()) {
node_to_outputs_map[node_name].insert(output->Name());
}
Expand Down Expand Up @@ -970,7 +981,21 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
const int number_of_ort_nodes = graph.NumberOfNodes();
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{nodes_vector, false}};
std::vector<size_t> filtered_nodes_vector;
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();

// We currently exclude the "If", "Loop" and "Scan" control flow ops from the original node vector before calling the TensorRT parser.
// The reason is that these control flow ops have subgraphs which might contain a TRT fused node after ORT partitioning.
// If that is the case, the TensorRT parser will complain about the unrecognized TRT fused node and fail.
for (const auto& index : nodes_vector) {
const auto& node = graph.GetNode(node_index[index]);
if (node->OpType() == "If" || node->OpType() == "Loop" || node->OpType() == "Scan") {
continue;
}
filtered_nodes_vector.push_back(index);
}

SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}};
bool early_termination = false;
supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
if (early_termination) {
Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
#include "core/providers/cuda/cuda_provider_factory.h"
#include "core/providers/cuda/gpu_data_transfer.h"
#endif
#ifdef USE_TENSORRT
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#endif
#ifdef USE_ROCM
#include "core/providers/rocm/rocm_provider_factory.h"
#include "core/providers/rocm/gpu_data_transfer.h"
Expand Down Expand Up @@ -1447,6 +1450,7 @@ TEST(InferenceSessionTests, Test3LayerNestedSubgraph) {
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
ONNX_NAMESPACE::TypeProto bool_tensor;
bool_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL);
bool_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

auto& if_cond_input = graph.GetOrCreateNodeArg("if_cond_input", &bool_tensor);
auto& graph_if_input = graph.GetOrCreateNodeArg("graph_if_input", nullptr);
Expand Down Expand Up @@ -1485,7 +1489,9 @@ TEST(InferenceSessionTests, Test3LayerNestedSubgraph) {
so.session_logid = "InferenceSessionTests.Test3LayerNestedSubgraph";
InferenceSession session_object{so, GetEnvironment()};

#if USE_CUDA
#if USE_TENSORRT
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultTensorrtExecutionProvider()));
#elif USE_CUDA
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
#elif USE_ROCM
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider()));
Expand Down Expand Up @@ -1562,11 +1568,13 @@ TEST(InferenceSessionTests, Test2LayerNestedSubgraph) {

ONNX_NAMESPACE::TypeProto float_tensor_input;
float_tensor_input.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
float_tensor_input.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
ONNX_NAMESPACE::TypeProto float_tensor;
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("__graph_0__float_unknown");
ONNX_NAMESPACE::TypeProto bool_tensor;
bool_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL);
bool_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

// graph inputs
auto& input_0 = graph.GetOrCreateNodeArg("input_0", &float_tensor_input);
Expand Down Expand Up @@ -1617,7 +1625,9 @@ TEST(InferenceSessionTests, Test2LayerNestedSubgraph) {
so.session_logid = "InferenceSessionTests.Test2LayerNestedSubgraph";
InferenceSession session_object{so, GetEnvironment()};

#if USE_CUDA
#if USE_TENSORRT
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultTensorrtExecutionProvider()));
#elif USE_CUDA
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
#elif USE_ROCM
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider()));
Expand Down

0 comments on commit b39257a

Please sign in to comment.