diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
index 930e6fac11c0..45cd888dc7f5 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -121,8 +121,7 @@ createLowerToBackendContractPass(int maxIterations, bool decompose,
                                  ArrayRef<std::string> backendLegalOps);
 
 std::unique_ptr<OperationPass<ModuleOp>>
-createVerifyBackendContractPass(bool decompose,
-                                ArrayRef<std::string> backendLegalOps);
+createVerifyBackendContractNoDecompositionsPass();
 
 StringRef getAbstractInterpLibrary();
 
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
index 5dcf2286b4cc..1ee87b36ec27 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
@@ -343,24 +343,17 @@ def LowerToBackendContract
   let dependentDialects = ["func::FuncDialect"];
 }
 
-def VerifyBackendContract
-    : Pass<"torch-verify-backend-contract", "ModuleOp"> {
+def VerifyBackendContractNoDecompositions
+    : Pass<"torch-verify-backend-contract-no-decompositions", "ModuleOp"> {
   let summary = "Check that program satisfies backend contract.";
   let constructor = [{
-    mlir::torch::Torch::createVerifyBackendContractPass(
-        /*decompose=*/true, /*backendLegalOps=*/{})
+    mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass()
   }];
   let description = [{
     This pass performs a set of inspections to check that program satisfies backend
-    contract. In case of check failure it prints out the error message and returns
-    `signalPassFailure()` status.
+    contract assuming that no decompositions were applied. In case of check failure
+    it prints out the error message and returns `signalPassFailure()` status.
   }];
-  let options = [
-    Option<"decompose", "decompose", "bool", /*default=*/"true",
-           "Decompose ops.">,
-    ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
-               "List of ops to be considered legal for the backend.">
-  ];
 }
 
 #endif // TORCHMLIR_TORCH_PASSES
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index a2db26627ae6..bf9c98d517b0 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -285,19 +285,16 @@ class LowerToBackendContractPass
   }
 };
 
-class VerifyBackendContractPass
-    : public VerifyBackendContractBase<VerifyBackendContractPass> {
+class VerifyBackendContractNoDecompositionsPass
+    : public VerifyBackendContractNoDecompositionsBase<VerifyBackendContractNoDecompositionsPass> {
 public:
-  VerifyBackendContractPass() = default;
-  VerifyBackendContractPass(bool decompose,
-                            ArrayRef<std::string> backendLegalOps) {
-    this->decompose = decompose;
-    this->backendLegalOps = backendLegalOps;
-  }
+  VerifyBackendContractNoDecompositionsPass() = default;
+
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target =
-        getBackendContractTarget(context, decompose, backendLegalOps);
+        getBackendContractTarget(context, /*decompose*/false,
+                                 /*backendLegalOps*/{});
 
     if (!satisfiesBackendContract(getOperation(), target,
                                   /*actuallyEmitDiagnostics=*/true)) {
@@ -315,10 +312,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::torch::Torch::createVerifyBackendContractPass(
-    bool decompose, ArrayRef<std::string> backendLegalOps) {
-  return std::make_unique<VerifyBackendContractPass>(decompose,
-                                                     backendLegalOps);
+mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
+  return std::make_unique<VerifyBackendContractNoDecompositionsPass>();
 }
 
 // The backend contract guarantees that ops with decompositions available will
diff --git a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt
index 68a604e2801d..3293c6e2f663 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt
+++ b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt
@@ -71,6 +71,7 @@ add_library(torch_mlir_ltc_backend SHARED
   mlir_node.cpp
   ops/device_data.cpp
   ops/generic.cpp
+  utils/jit_utils.cpp
   utils/tensor_utils.cpp
 )
 target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
index ec234dc77e02..6010d3dd8718 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
@@ -16,6 +16,9 @@
 #include <torch/csrc/jit/passes/refine_tuple_types.h>
 #include <torch/csrc/lazy/core/lazy_graph_executor.h>
 #include "torch-mlir-c/Registration.h"
+#include "torch-mlir-c/Transforms.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Pass.h"
 
 #include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
 #include "backend_impl.h"
@@ -23,6 +26,7 @@
 #include "mlir_node.h"
 #include "utils/debug.h"
 #include "utils/exception.h"
+#include "utils/jit_utils.h"
 #include "utils/string_utils.h"
 #include "utils/sys_utils.h"
 
@@ -135,6 +139,11 @@ ComputationPtr TorchMlirLoweringContext::Build() {
     graph_->block()->registerOutput(output);
   }
 
+  // During operations lowering JIT may insert ScalarImplicit ops which output
+  // type !torch.number doesn't represent any existing MLIR type and should be
+  // refined either to Torch::IntType or Torch::FloatType.
+  torch::jit::ConvertScalarImplicit(graph_);
+
   // Generate MLIR.
   MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp(
       /*context=*/mlir_context_,
@@ -142,12 +151,35 @@ ComputationPtr TorchMlirLoweringContext::Build() {
       /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
       /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
 
-  return CreateComputation(func_op);
+
+  // Convert MlirOperation to MlirModule.
+  MlirLocation loc = mlirLocationUnknownGet(mlir_context_);
+  MlirModule module_op = mlirModuleCreateEmpty(loc);
+  MlirBlock block = mlirModuleGetBody(module_op);
+  mlirBlockAppendOwnedOperation(block, func_op);
+
+  // Apply passes to verify generated MLIR.
+  auto pass_manager = mlirPassManagerCreate(mlir_context_);
+  mlirPassManagerAddOwnedPass(
+      pass_manager,
+      mlirCreateVerifyBackendContractNoDecompositions()
+  );
+
+  MlirLogicalResult result = mlirPassManagerRun(
+      pass_manager,
+      module_op
+  );
+
+  if (mlirLogicalResultIsFailure(result)) {
+    throw std::runtime_error("MLIR verification has failed.");
+  }
+
+  return CreateComputation(module_op);
 }
 
-ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirOperation func_op) {
+ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirModule module_op) {
   return std::make_shared<TorchMlirComputation>(
-      func_op, mlir_context_, graph_, parameter_names_, input_output_aliases_);
+      module_op, mlir_context_, graph_, parameter_names_, input_output_aliases_);
 }
 
 torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
@@ -295,11 +327,11 @@ void TorchMlirLoweringContext::RegisterMlirDialects() {
 ///////////////////////////////////////////////////////////////////////////////
 
 TorchMlirComputation::TorchMlirComputation(
-    MlirOperation func_op, MlirContext mlir_context,
+    MlirModule module_op, MlirContext mlir_context,
     const std::shared_ptr<torch::jit::Graph>& graph,
     std::unordered_map<int, std::string> parameters_map,
     InputOutputAliases input_output_aliases)
-    : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
+    : module_op_(std::move(module_op)), mlir_context_(std::move(mlir_context)),
       graph_(graph), input_output_aliases_(input_output_aliases),
       parameters_map_(parameters_map) {
 
@@ -340,7 +372,14 @@ std::shared_ptr<torch::jit::Graph> TorchMlirComputation::graph() const {
   return graph_;
 }
 
-MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
+MlirOperation TorchMlirComputation::func_op() const {
+  MlirBlock block = mlirModuleGetBody(module_op_);
+  return mlirBlockGetFirstOperation(block);
+}
+
+MlirModule TorchMlirComputation::module_op() const {
+  return module_op_;
+}
 
 MlirContext TorchMlirComputation::mlir_context() const {
   return mlir_context_;
@@ -385,7 +424,7 @@ const std::string TorchMlirComputation::to_string() const {
     *ss_ptr << std::string(part.data, part.length);
   };
   std::stringstream ss;
-  mlirOperationPrint(func_op_, print_callback, &ss);
+  mlirOperationPrint(mlirModuleGetOperation(module_op_), print_callback, &ss);
   return ss.str();
 }
 
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
index 61e18f4106c9..f62a71ce7945 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
@@ -73,7 +73,7 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
   // embedded builder (returned by the builder() API).
   torch::lazy::ComputationPtr Build() override;
 
-  virtual torch::lazy::ComputationPtr CreateComputation(MlirOperation func_op);
+  virtual torch::lazy::ComputationPtr CreateComputation(MlirModule module_op);
 
   // Retrieves the lowered operation for an output. If the requested output is
   // not available yet, the graph behind the output's Node is lowered, and the
@@ -123,7 +123,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
   using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias;
 
   TorchMlirComputation(
-      MlirOperation func_op, MlirContext mlir_context,
+      MlirModule module_op, MlirContext mlir_context,
       const std::shared_ptr<torch::jit::Graph>& graph,
       std::unordered_map<int, std::string> parameters_map,
       InputOutputAliases input_output_aliases);
@@ -142,6 +142,8 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
 
   MlirOperation func_op() const;
 
+  MlirModule module_op() const;
+
   MlirContext mlir_context() const;
 
   virtual const std::string debug_string() const;
@@ -155,7 +157,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
   std::vector<Shape> parameter_shapes_;
   Shape result_shape_;
 
-  MlirOperation func_op_;
+  MlirModule module_op_;
   MlirContext mlir_context_;
   std::shared_ptr<torch::jit::Graph> graph_;
   InputOutputAliases input_output_aliases_;
diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp
new file mode 100644
index 000000000000..8d64f9fb7c9b
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp
@@ -0,0 +1,45 @@
+#include "jit_utils.h"
+
+#include <c10/core/ScalarType.h>
+
+#include <torch/csrc/jit/runtime/graph_iterator.h>
+
+namespace torch {
+namespace jit {
+
+void ConvertScalarImplicit(std::shared_ptr<Graph>& graph) {
+  DepthFirstGraphNodeIterator it(graph);
+  for (auto* node = it.next(); node != nullptr; node = it.next()) {
+    if (node->kind() != c10::aten::ScalarImplicit) {
+      continue;
+    }
+
+    auto input = node->input(0);
+    auto scalar_type = input->type()->cast<c10::TensorType>()->scalarType();
+    TORCH_CHECK(scalar_type, "scalar type is not defined for input value");
+
+    NodeKind node_type;
+    TypePtr output_type;
+    if (c10::isIntegralType(*scalar_type, false)) {
+      node_type = c10::aten::IntImplicit;
+      output_type = IntType::get();
+    } else if (c10::isFloatingType(*scalar_type)) {
+      node_type = c10::aten::FloatImplicit;
+      output_type = FloatType::get();
+    } else {
+      throw std::runtime_error(
+          "Expected isIntegralType or isFloatingType");
+    }
+
+    Value * output = graph
+        ->create(node_type, {input})
+        ->insertBefore(node)
+        ->output()
+        ->setType(output_type);
+    node->output()->replaceAllUsesWith(output);
+    node->destroy();
+  }
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h
new file mode 100644
index 000000000000..2c4214cfc1ab
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h
@@ -0,0 +1,10 @@
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+// Convert ScalarImplicit to IntImplicit or FloatImplicit.
+TORCH_API void ConvertScalarImplicit(std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
diff --git a/test/Dialect/Torch/verify-backend-contract-error.mlir b/test/Dialect/Torch/verify-backend-contract-error.mlir
index 5accee12615a..eb9c6c581a99 100644
--- a/test/Dialect/Torch/verify-backend-contract-error.mlir
+++ b/test/Dialect/Torch/verify-backend-contract-error.mlir
@@ -1,7 +1,7 @@
-// RUN: torch-mlir-opt -torch-verify-backend-contract -split-input-file -verify-diagnostics %s
-func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  // expected-error @+2 {{found an op that was marked as backend illegal}}
-  // expected-note @+1 {{this is likely due to}}
-  %t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %t : !torch.vtensor<[?,?],f32>
+// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s
+func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
+  // expected-error @below {{unsupported by backend contract: tensor with unknown rank}}
+  // expected-note @below {{this is likely due to a missing transfer function}}
+  %t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor
+  return %t : !torch.vtensor
 }