diff --git a/onnxruntime/core/optimizer/bias_dropout_fusion.cc b/onnxruntime/core/optimizer/bias_dropout_fusion.cc index 442a9ae88a915..e27750f6ae0bb 100644 --- a/onnxruntime/core/optimizer/bias_dropout_fusion.cc +++ b/onnxruntime/core/optimizer/bias_dropout_fusion.cc @@ -15,49 +15,56 @@ void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node, std::vector& dropout_output, std::vector>& nodes_to_fuse) { bool has_residual_add = false; - for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) { - const Node& last_node = (*last_node_itr); - - if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13}) && - last_node.GetExecutionProviderType() == dropout_node.GetExecutionProviderType()) { - const TensorShapeProto* input1_shape = last_node.InputDefs()[0]->Shape(); - const TensorShapeProto* input2_shape = last_node.InputDefs()[1]->Shape(); - - if (input1_shape == nullptr || - input2_shape == nullptr || - input1_shape->dim_size() < 1 || - input2_shape->dim_size() < 1 || - input1_shape->dim_size() != input2_shape->dim_size()) { - continue; - } - - // Inputs of Residual Add must match in shape - bool match = true; - for (int i = 0; i < input1_shape->dim_size(); ++i) { - match &= ONNX_NAMESPACE::operator==(input1_shape->dim(i), input2_shape->dim(i)); - } - if (!match) { - continue; - } - - // dropout's output is not part of of graph output - if (!graph.GetNodeOutputsInGraphOutputs(dropout_node).empty()) { - continue; - } - Node& residual_add_node = *graph.GetNode(last_node.Index()); - const std::string& dropout_output_name = dropout_node.OutputDefs()[0]->Name(); - if (dropout_output_name == residual_add_node.InputDefs()[0]->Name()) { - dropout_input.push_back(residual_add_node.MutableInputDefs()[1]); // residual - } else if (dropout_output_name == residual_add_node.InputDefs()[1]->Name()) { - dropout_input.push_back(residual_add_node.MutableInputDefs()[0]); // residual + int dropout_consumers_count = 0; + for (auto edge_itr = dropout_node.OutputEdgesBegin(); edge_itr != dropout_node.OutputEdgesEnd(); ++edge_itr) { + if (edge_itr->GetSrcArgIndex() == 0) { + ++dropout_consumers_count; + } + } + // To be able to fuse the residual Add, + // the Dropout's output must not be a graph output and + // there must be only one consumer of the Dropout's first output. + if (dropout_consumers_count < 2 && graph.GetNodeOutputsInGraphOutputs(dropout_node).empty()) { + for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) { + const Node& last_node = (*last_node_itr); + + if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7, 13}) && + last_node.GetExecutionProviderType() == dropout_node.GetExecutionProviderType()) { + const TensorShapeProto* input1_shape = last_node.InputDefs()[0]->Shape(); + const TensorShapeProto* input2_shape = last_node.InputDefs()[1]->Shape(); + + if (input1_shape == nullptr || + input2_shape == nullptr || + input1_shape->dim_size() < 1 || + input2_shape->dim_size() < 1 || + input1_shape->dim_size() != input2_shape->dim_size()) { + continue; + } + + // Inputs of Residual Add must match in shape + bool match = true; + for (int i = 0; i < input1_shape->dim_size(); ++i) { + match &= ONNX_NAMESPACE::operator==(input1_shape->dim(i), input2_shape->dim(i)); + } + if (!match) { + continue; + } + + Node& residual_add_node = *graph.GetNode(last_node.Index()); + const std::string& dropout_output_name = dropout_node.OutputDefs()[0]->Name(); + if (dropout_output_name == residual_add_node.InputDefs()[0]->Name()) { + dropout_input.push_back(residual_add_node.MutableInputDefs()[1]); // residual + } else if (dropout_output_name == residual_add_node.InputDefs()[1]->Name()) { + dropout_input.push_back(residual_add_node.MutableInputDefs()[0]); // residual + } + + dropout_output[0] = residual_add_node.MutableOutputDefs()[0]; + + nodes_to_fuse.push_back(residual_add_node); + has_residual_add = true; + break; } - - dropout_output[0] = residual_add_node.MutableOutputDefs()[0]; - - nodes_to_fuse.push_back(residual_add_node); - has_residual_add = true; - break; } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index c31267b1baf7a..d0b0a47e58c5f 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2944,6 +2944,8 @@ TEST_F(GraphTransformationTests, BiasDropoutFusionTest) { TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_residual_fusion1.onnx", *logger_); TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_residual_fusion2.onnx", *logger_); TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_residual_fusion_mismatch.onnx", *logger_, 1); + TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_residual_fusion_multiple_consumers1.onnx", *logger_, 1); + TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_residual_fusion_multiple_consumers2.onnx", *logger_, 1); } TEST_F(GraphTransformationTests, LayerNormFusionTest) { diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion_multiple_consumers1.onnx b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion_multiple_consumers1.onnx new file mode 100644 index 0000000000000..48649903119c1 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion_multiple_consumers1.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion_multiple_consumers2.onnx b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion_multiple_consumers2.onnx new file mode 100644 index 0000000000000..dc6e6008faa5b Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion_multiple_consumers2.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py index 206e44af83b52..b0642194c6a2a 100644 --- a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py @@ -96,4 +96,40 @@ [ratio, training_mode]) model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_residual_fusion_mismatch.onnx') \ No newline at end of file +onnx.save(model, 'bias_dropout_residual_fusion_mismatch.onnx') + +# If the Dropout output 0 is also a graph output, the residual Add shouldn't be fused. +# Create the model (ModelProto) +bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") +dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +residual = helper.make_node("Add", ["R", "dropout_out"], ["C"], "add1") + +D = helper.make_tensor_value_info('dropout_out', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) + +graph = helper.make_graph( + [bias, dropout_12, residual], + "Bias_Dropout_Fusion", #name + [A, B, R], + [C, D], + [ratio, training_mode]) + +model = helper.make_model(graph, producer_name='onnx-example', **kwargs) +onnx.save(model, 'bias_dropout_residual_fusion_multiple_consumers1.onnx') + +# If the Dropout has multiple consumers of output 0, the residual Add shouldn't be fused. +# Create the model (ModelProto) +D = helper.make_tensor_value_info('D', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) +bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") +dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +residual = helper.make_node("Add", ["R", "dropout_out"], ["C"], "add1") +identity = helper.make_node("Identity", ["dropout_out"], ["D"], "identity") + +graph = helper.make_graph( + [bias, dropout_12, residual, identity], + "Bias_Dropout_Fusion", #name + [A, B, R], + [C, D], + [ratio, training_mode]) + +model = helper.make_model(graph, producer_name='onnx-example', **kwargs) +onnx.save(model, 'bias_dropout_residual_fusion_multiple_consumers2.onnx')