From e528c43054c3c63ad513cb2adf69fa854ed78ead Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Wed, 11 Jan 2023 16:13:55 -0800 Subject: [PATCH 1/5] Enable VerifyBackendContract in LTC backend --- .../Dialect/Torch/Transforms/Passes.td | 4 +- .../mlir_lowering_context.cpp | 53 ++++++++++++++++--- .../base_lazy_backend/mlir_lowering_context.h | 8 +-- 3 files changed, 53 insertions(+), 12 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 5dcf2286b4cc..1400250f878d 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -348,7 +348,7 @@ def VerifyBackendContract let summary = "Check that program satisfies backend contract."; let constructor = [{ mlir::torch::Torch::createVerifyBackendContractPass( - /*decompose=*/true, /*backendLegalOps=*/{}) + /*decompose=*/false, /*backendLegalOps=*/{}) }]; let description = [{ This pass performs a set of inspections to check that program satisfies backend @@ -356,7 +356,7 @@ def VerifyBackendContract `signalPassFailure()` status. }]; let options = [ - Option<"decompose", "decompose", "bool", /*default=*/"true", + Option<"decompose", "decompose", "bool", /*default=*/"false", "Decompose ops.">, ListOption<"backendLegalOps", "backend-legal-ops", "std::string", "List of ops to be considered legal for the backend."> diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index ec234dc77e02..9ffb4cb1def7 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -14,8 +14,12 @@ #include #include +#include #include #include "torch-mlir-c/Registration.h" +#include "torch-mlir-c/Transforms.h" +#include "mlir-c/IR.h" +#include "mlir-c/Pass.h" #include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h" #include "backend_impl.h" @@ -135,6 +139,11 @@ ComputationPtr TorchMlirLoweringContext::Build() { graph_->block()->registerOutput(output); } + // During operations lowering JIT may insert ScalarImplicit ops which output + // type !torch.number doesn't represent any existing MLIR type and should be + // refined either to Torch::IntType or Torch::FloatType. + torch::jit::ConvertScalarImplicit(graph_); + // Generate MLIR. MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp( /*context=*/mlir_context_, @@ -142,12 +151,35 @@ ComputationPtr TorchMlirLoweringContext::Build() { /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; }, /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true}); - return CreateComputation(func_op); + + // Convert MlirOperation to MlirModule. + MlirLocation loc = mlirLocationUnknownGet(mlir_context_); + MlirModule module_op = mlirModuleCreateEmpty(loc); + MlirBlock block = mlirModuleGetBody(module_op); + mlirBlockAppendOwnedOperation(block, func_op); + + // Apply passes to verify generated MLIR. + auto pass_manager = mlirPassManagerCreate(mlir_context_); + mlirPassManagerAddOwnedPass( + pass_manager, + mlirCreateVerifyBackendContract() + ); + + MlirLogicalResult result = mlirPassManagerRun( + pass_manager, + module_op + ); + + if (mlirLogicalResultIsFailure(result)) { + throw std::runtime_error("MLIR verification has failed."); + } + + return CreateComputation(module_op); } -ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirOperation func_op) { +ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirModule module_op) { return std::make_shared( - func_op, mlir_context_, graph_, parameter_names_, input_output_aliases_); + module_op, mlir_context_, graph_, parameter_names_, input_output_aliases_); } torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { @@ -295,11 +327,11 @@ void TorchMlirLoweringContext::RegisterMlirDialects() { /////////////////////////////////////////////////////////////////////////////// TorchMlirComputation::TorchMlirComputation( - MlirOperation func_op, MlirContext mlir_context, + MlirModule module_op, MlirContext mlir_context, const std::shared_ptr& graph, std::unordered_map parameters_map, InputOutputAliases input_output_aliases) - : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)), + : module_op_(std::move(module_op)), mlir_context_(std::move(mlir_context)), graph_(graph), input_output_aliases_(input_output_aliases), parameters_map_(parameters_map) { @@ -340,7 +372,14 @@ std::shared_ptr TorchMlirComputation::graph() const { return graph_; } -MlirOperation TorchMlirComputation::func_op() const { return func_op_; } +MlirOperation TorchMlirComputation::func_op() const { + MlirBlock block = mlirModuleGetBody(module_op_); + return mlirBlockGetFirstOperation(block); +} + +MlirModule TorchMlirComputation::module_op() const { + return module_op_; +} MlirContext TorchMlirComputation::mlir_context() const { return mlir_context_; @@ -385,7 +424,7 @@ const std::string TorchMlirComputation::to_string() const { *ss_ptr << std::string(part.data, part.length); }; std::stringstream ss; - mlirOperationPrint(func_op_, print_callback, &ss); + mlirOperationPrint(mlirModuleGetOperation(module_op_), print_callback, &ss); return ss.str(); } diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h index 61e18f4106c9..f62a71ce7945 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h @@ -73,7 +73,7 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { // embedded builder (returned by the builder() API). torch::lazy::ComputationPtr Build() override; - virtual torch::lazy::ComputationPtr CreateComputation(MlirOperation func_op); + virtual torch::lazy::ComputationPtr CreateComputation(MlirModule module_op); // Retrieves the lowered operation for an output. If the requested output is // not available yet, the graph behind the output's Node is lowered, and the @@ -123,7 +123,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation { using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias; TorchMlirComputation( - MlirOperation func_op, MlirContext mlir_context, + MlirModule module_op, MlirContext mlir_context, const std::shared_ptr& graph, std::unordered_map parameters_map, InputOutputAliases input_output_aliases); @@ -142,6 +142,8 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation { MlirOperation func_op() const; + MlirModule module_op() const; + MlirContext mlir_context() const; virtual const std::string debug_string() const; @@ -155,7 +157,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation { std::vector parameter_shapes_; Shape result_shape_; - MlirOperation func_op_; + MlirModule module_op_; MlirContext mlir_context_; std::shared_ptr graph_; InputOutputAliases input_output_aliases_; From dea5b8ac5b0d15925f7abadc4c8845ba77e4d819 Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Thu, 12 Jan 2023 12:46:51 -0800 Subject: [PATCH 2/5] Update VerifyBackendContract pass --- .../torch-mlir/Dialect/Torch/Transforms/Passes.h | 3 +-- .../torch-mlir/Dialect/Torch/Transforms/Passes.td | 9 +-------- .../Torch/Transforms/LowerToBackendContract.cpp | 15 +++++---------- 3 files changed, 7 insertions(+), 20 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 930e6fac11c0..36ea6de587b5 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -121,8 +121,7 @@ createLowerToBackendContractPass(int maxIterations, bool decompose, ArrayRef backendLegalOps); std::unique_ptr> -createVerifyBackendContractPass(bool decompose, - ArrayRef backendLegalOps); +createVerifyBackendContractPass(); StringRef getAbstractInterpLibrary(); diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 1400250f878d..9a123cd82d8f 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -347,20 +347,13 @@ def VerifyBackendContract : Pass<"torch-verify-backend-contract", "ModuleOp"> { let summary = "Check that program satisfies backend contract."; let constructor = [{ - mlir::torch::Torch::createVerifyBackendContractPass( - /*decompose=*/false, /*backendLegalOps=*/{}) + mlir::torch::Torch::createVerifyBackendContractPass() }]; let description = [{ This pass performs a set of inspections to check that program satisfies backend contract. In case of check failure it prints out the error message and returns `signalPassFailure()` status. }]; - let options = [ - Option<"decompose", "decompose", "bool", /*default=*/"false", - "Decompose ops.">, - ListOption<"backendLegalOps", "backend-legal-ops", "std::string", - "List of ops to be considered legal for the backend."> - ]; } #endif // TORCHMLIR_TORCH_PASSES diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index a2db26627ae6..b258b3cf34a3 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -289,15 +289,12 @@ class VerifyBackendContractPass : public VerifyBackendContractBase { public: VerifyBackendContractPass() = default; - VerifyBackendContractPass(bool decompose, - ArrayRef backendLegalOps) { - this->decompose = decompose; - this->backendLegalOps = backendLegalOps; - } + void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target = - getBackendContractTarget(context, decompose, backendLegalOps); + getBackendContractTarget(context, /*decompose*/false, + /*backendLegalOps*/{}); if (!satisfiesBackendContract(getOperation(), target, /*actuallyEmitDiagnostics=*/true)) { @@ -315,10 +312,8 @@ mlir::torch::Torch::createLowerToBackendContractPass( } std::unique_ptr> -mlir::torch::Torch::createVerifyBackendContractPass( - bool decompose, ArrayRef backendLegalOps) { - return std::make_unique(decompose, - backendLegalOps); +mlir::torch::Torch::createVerifyBackendContractPass() { + return std::make_unique(); } // The backend contract guarantees that ops with decompositions available will From d7ce4012e3f5649aadf1acd6995a33233ed14def Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Mon, 16 Jan 2023 09:41:46 -0800 Subject: [PATCH 3/5] Move convert_scalar_implicit to jit_utils --- .../csrc/base_lazy_backend/CMakeLists.txt | 1 + .../mlir_lowering_context.cpp | 2 +- .../base_lazy_backend/utils/jit_utils.cpp | 45 +++++++++++++++++++ .../csrc/base_lazy_backend/utils/jit_utils.h | 10 +++++ 4 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp create mode 100644 python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h diff --git a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt index 68a604e2801d..3293c6e2f663 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt +++ b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt @@ -71,6 +71,7 @@ add_library(torch_mlir_ltc_backend SHARED mlir_node.cpp ops/device_data.cpp ops/generic.cpp + utils/jit_utils.cpp utils/tensor_utils.cpp ) target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17) diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index 9ffb4cb1def7..af5001b6e463 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -14,7 +14,6 @@ #include #include -#include #include #include "torch-mlir-c/Registration.h" #include "torch-mlir-c/Transforms.h" @@ -27,6 +26,7 @@ #include "mlir_node.h" #include "utils/debug.h" #include "utils/exception.h" +#include "utils/jit_utils.h" #include "utils/string_utils.h" #include "utils/sys_utils.h" diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp new file mode 100644 index 000000000000..8d64f9fb7c9b --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp @@ -0,0 +1,45 @@ +#include "jit_utils.h" + +#include + +#include + +namespace torch { +namespace jit { + +void ConvertScalarImplicit(std::shared_ptr& graph) { + DepthFirstGraphNodeIterator it(graph); + for (auto* node = it.next(); node != nullptr; node = it.next()) { + if (node->kind() != c10::aten::ScalarImplicit) { + continue; + } + + auto input = node->input(0); + auto scalar_type = input->type()->cast()->scalarType(); + TORCH_CHECK(scalar_type, "scalar type is not defined for input value"); + + NodeKind node_type; + TypePtr output_type; + if (c10::isIntegralType(*scalar_type, false)) { + node_type = c10::aten::IntImplicit; + output_type = IntType::get(); + } else if (c10::isFloatingType(*scalar_type)) { + node_type = c10::aten::FloatImplicit; + output_type = FloatType::get(); + } else { + throw std::runtime_error( + "Expected isIntegralType or isFloatingType"); + } + + Value * output = graph + ->create(node_type, {input}) + ->insertBefore(node) + ->output() + ->setType(output_type); + node->output()->replaceAllUsesWith(output); + node->destroy(); + } +} + +} // namespace jit +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h new file mode 100644 index 000000000000..2c4214cfc1ab --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h @@ -0,0 +1,10 @@ +#include + +namespace torch { +namespace jit { + +// Convert ScalarImplicit to IntImplicit or FloatImplicit. +TORCH_API void ConvertScalarImplicit(std::shared_ptr& graph); + +} // namespace jit +} // namespace torch From bdddd7f15f481506537bf95ef54cc17aeb138ac9 Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Mon, 23 Jan 2023 10:02:09 -0800 Subject: [PATCH 4/5] Rename VerifyBackendContract to VerifyBackendContractNoDecompositions --- include/torch-mlir/Dialect/Torch/Transforms/Passes.h | 2 +- include/torch-mlir/Dialect/Torch/Transforms/Passes.td | 10 +++++----- .../Torch/Transforms/LowerToBackendContract.cpp | 10 +++++----- .../csrc/base_lazy_backend/mlir_lowering_context.cpp | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 36ea6de587b5..45cd888dc7f5 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -121,7 +121,7 @@ createLowerToBackendContractPass(int maxIterations, bool decompose, ArrayRef backendLegalOps); std::unique_ptr> -createVerifyBackendContractPass(); +createVerifyBackendContractNoDecompositionsPass(); StringRef getAbstractInterpLibrary(); diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 9a123cd82d8f..1ee87b36ec27 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -343,16 +343,16 @@ def LowerToBackendContract let dependentDialects = ["func::FuncDialect"]; } -def VerifyBackendContract - : Pass<"torch-verify-backend-contract", "ModuleOp"> { +def VerifyBackendContractNoDecompositions + : Pass<"torch-verify-backend-contract-no-decompositions", "ModuleOp"> { let summary = "Check that program satisfies backend contract."; let constructor = [{ - mlir::torch::Torch::createVerifyBackendContractPass() + mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() }]; let description = [{ This pass performs a set of inspections to check that program satisfies backend - contract. In case of check failure it prints out the error message and returns - `signalPassFailure()` status. + contract assuming that no decompositions were applied. In case of check failure + it prints out the error message and returns `signalPassFailure()` status. }]; } diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index b258b3cf34a3..bf9c98d517b0 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -285,10 +285,10 @@ class LowerToBackendContractPass } }; -class VerifyBackendContractPass - : public VerifyBackendContractBase { +class VerifyBackendContractNoDecompositionsPass + : public VerifyBackendContractNoDecompositionsBase { public: - VerifyBackendContractPass() = default; + VerifyBackendContractNoDecompositionsPass() = default; void runOnOperation() override { MLIRContext *context = &getContext(); @@ -312,8 +312,8 @@ mlir::torch::Torch::createLowerToBackendContractPass( } std::unique_ptr> -mlir::torch::Torch::createVerifyBackendContractPass() { - return std::make_unique(); +mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() { + return std::make_unique(); } // The backend contract guarantees that ops with decompositions available will diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index af5001b6e463..6010d3dd8718 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -162,7 +162,7 @@ ComputationPtr TorchMlirLoweringContext::Build() { auto pass_manager = mlirPassManagerCreate(mlir_context_); mlirPassManagerAddOwnedPass( pass_manager, - mlirCreateVerifyBackendContract() + mlirCreateVerifyBackendContractNoDecompositions() ); MlirLogicalResult result = mlirPassManagerRun( From 82488639ccb264750b23dc859be06741dad68ada Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Mon, 23 Jan 2023 10:28:36 -0800 Subject: [PATCH 5/5] Update verify-backend-contract-error.mlir test --- .../Dialect/Torch/verify-backend-contract-error.mlir | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/Dialect/Torch/verify-backend-contract-error.mlir b/test/Dialect/Torch/verify-backend-contract-error.mlir index 5accee12615a..eb9c6c581a99 100644 --- a/test/Dialect/Torch/verify-backend-contract-error.mlir +++ b/test/Dialect/Torch/verify-backend-contract-error.mlir @@ -1,7 +1,7 @@ -// RUN: torch-mlir-opt -torch-verify-backend-contract -split-input-file -verify-diagnostics %s -func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { - // expected-error @+2 {{found an op that was marked as backend illegal}} - // expected-note @+1 {{this is likely due to}} - %t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> - return %t : !torch.vtensor<[?,?],f32> +// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s +func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor { + // expected-error @below {{unsupported by backend contract: tensor with unknown rank}} + // expected-note @below {{this is likely due to a missing transfer function}} + %t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor + return %t : !torch.vtensor }