Skip to content

Commit

Permalink
[DML EP] Disable DML Graph Fusion for lower graph optimization level …
Browse files Browse the repository at this point in the history
…OR setOptimizedFilePath true (#13913)

### Description
DML EP won't fuse the ONNX Graph if ORT Graph optimization level is <= 1
or `SessionOption::SetOptimizedFilePath` is passed.

This is the successor of
#11346.

### Motivation and Context
- **Why is this change required? What problem does it solve?**  
Requested by few a users (issues below) and also helps in debugging.
- **If it fixes an open issue, please link to the issue here:**
  - #13535
  - #8440
  • Loading branch information
sumitsays committed Dec 12, 2022
1 parent 8cfbc4f commit fe827c3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,21 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling),
m_function(function)
{
DmlOperator::Initialize(kernelInfo);
const bool hasDilations =
std::any_of(
m_kernel.dilations,
m_kernel.dilations + m_kernel.spatialDimensionCount,
[](auto d) {return d != 1; }
);
const bool hasOutputIndices = (kernelInfo.GetOutputCount() > 1 && kernelInfo.IsOutputValid(1));
std::vector<std::optional<uint32_t>> kernelOutputIndices = {0};

if (function == DML_OPERATOR_MAX_POOLING2 && (hasOutputIndices || hasDilations))
{
kernelOutputIndices.emplace_back(1);
}
DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices);

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1.");
Expand All @@ -33,13 +46,6 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
int storageOrder = kernelInfo.GetOptionalAttribute<int>(AttrName::StorageOrder, 0);
ORT_THROW_HR_IF(E_NOTIMPL, storageOrder != 0);

const bool hasDilations =
std::any_of(
m_kernel.dilations,
m_kernel.dilations + m_kernel.spatialDimensionCount,
[](auto d) {return d != 1; }
);

// DML requires that DimensionCount be equal to Input.DimCount - 2 for Pooling
uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2;
if (m_kernel.spatialDimensionCount < expectedSpatialDimCount)
Expand Down Expand Up @@ -104,7 +110,6 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
case DML_OPERATOR_MAX_POOLING1:
case DML_OPERATOR_MAX_POOLING2:
{
bool hasOutputIndices = (outputDescs.size() > 1 && outputDescs[1].Desc != nullptr);
if (hasOutputIndices || hasDilations)
{
DML_MAX_POOLING2_OPERATOR_DESC desc = {};
Expand Down
25 changes: 16 additions & 9 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1403,19 +1403,26 @@ common::Status InferenceSession::Initialize() {

#ifdef USE_DML
if (execution_providers_.Get(kDmlExecutionProvider)) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlGraphFusionTransformer = std::make_unique<Dml::DmlGraphFusionTransformer>("DmlGraphFusionTransformer",
execution_providers_.Get(kDmlExecutionProvider));
if (dmlGraphFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr");
bool dml_graph_fusion_enabled = session_options_.optimized_model_filepath.empty() &&
session_options_.graph_optimization_level >= TransformerLevel::Level3;
if (dml_graph_fusion_enabled) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlGraphFusionTransformer = std::make_unique<Dml::DmlGraphFusionTransformer>("DmlGraphFusionTransformer",
execution_providers_.Get(kDmlExecutionProvider));
if (dmlGraphFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr");
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3));

// This transformer applies DML-specific fusions that go beyond what ORT offers by default
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer");
if (dmlOperatorFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr");
bool dml_operator_fusion_enabled = session_options_.graph_optimization_level >= TransformerLevel::Level2;
if (dml_operator_fusion_enabled) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer");
if (dmlOperatorFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr");
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlOperatorFusionTransformer), onnxruntime::TransformerLevel::Level2));
}
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformation_mgr_.Register(std::move(dmlOperatorFusionTransformer), onnxruntime::TransformerLevel::Level2));
}
#endif

Expand Down

0 comments on commit fe827c3

Please sign in to comment.