Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ createLowerToBackendContractPass(int maxIterations, bool decompose,
ArrayRef<std::string> backendLegalOps);

std::unique_ptr<OperationPass<ModuleOp>>
createVerifyBackendContractPass(bool decompose,
ArrayRef<std::string> backendLegalOps);
createVerifyBackendContractNoDecompositionsPass();

StringRef getAbstractInterpLibrary();

Expand Down
17 changes: 5 additions & 12 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -343,24 +343,17 @@ def LowerToBackendContract
let dependentDialects = ["func::FuncDialect"];
}

def VerifyBackendContract
: Pass<"torch-verify-backend-contract", "ModuleOp"> {
def VerifyBackendContractNoDecompositions
: Pass<"torch-verify-backend-contract-no-decompositions", "ModuleOp"> {
let summary = "Check that program satisfies backend contract.";
let constructor = [{
mlir::torch::Torch::createVerifyBackendContractPass(
/*decompose=*/true, /*backendLegalOps=*/{})
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass()
}];
let description = [{
This pass performs a set of inspections to check that program satisfies backend
contract. In case of check failure it prints out the error message and returns
`signalPassFailure()` status.
contract assuming that no decompositions were applied. In case of check failure
it prints out the error message and returns `signalPassFailure()` status.
}];
let options = [
Option<"decompose", "decompose", "bool", /*default=*/"true",
"Decompose ops.">,
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
"List of ops to be considered legal for the backend.">
];
}

#endif // TORCHMLIR_TORCH_PASSES
21 changes: 8 additions & 13 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,19 +285,16 @@ class LowerToBackendContractPass
}
};

class VerifyBackendContractPass
: public VerifyBackendContractBase<VerifyBackendContractPass> {
class VerifyBackendContractNoDecompositionsPass
: public VerifyBackendContractNoDecompositionsBase<VerifyBackendContractNoDecompositionsPass> {
public:
VerifyBackendContractPass() = default;
VerifyBackendContractPass(bool decompose,
ArrayRef<std::string> backendLegalOps) {
this->decompose = decompose;
this->backendLegalOps = backendLegalOps;
}
VerifyBackendContractNoDecompositionsPass() = default;

void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target =
getBackendContractTarget(context, decompose, backendLegalOps);
getBackendContractTarget(context, /*decompose*/false,
/*backendLegalOps*/{});

if (!satisfiesBackendContract(getOperation(), target,
/*actuallyEmitDiagnostics=*/true)) {
Expand All @@ -315,10 +312,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyBackendContractPass(
bool decompose, ArrayRef<std::string> backendLegalOps) {
return std::make_unique<VerifyBackendContractPass>(decompose,
backendLegalOps);
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
return std::make_unique<VerifyBackendContractNoDecompositionsPass>();
}

// The backend contract guarantees that ops with decompositions available will
Expand Down
1 change: 1 addition & 0 deletions python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ add_library(torch_mlir_ltc_backend SHARED
mlir_node.cpp
ops/device_data.cpp
ops/generic.cpp
utils/jit_utils.cpp
utils/tensor_utils.cpp
)
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
Expand Down
53 changes: 46 additions & 7 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include "torch-mlir-c/Registration.h"
#include "torch-mlir-c/Transforms.h"
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"

#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
#include "backend_impl.h"
#include "mlir_lowering_context.h"
#include "mlir_node.h"
#include "utils/debug.h"
#include "utils/exception.h"
#include "utils/jit_utils.h"
#include "utils/string_utils.h"
#include "utils/sys_utils.h"

Expand Down Expand Up @@ -135,19 +139,47 @@ ComputationPtr TorchMlirLoweringContext::Build() {
graph_->block()->registerOutput(output);
}

// During operations lowering JIT may insert ScalarImplicit ops which output
// type !torch.number doesn't represent any existing MLIR type and should be
// refined either to Torch::IntType or Torch::FloatType.
torch::jit::ConvertScalarImplicit(graph_);

// Generate MLIR.
MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp(
/*context=*/mlir_context_,
/*function=*/generate_jit_fn().get(),
/*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
/*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});

return CreateComputation(func_op);

// Convert MlirOperation to MlirModule.
MlirLocation loc = mlirLocationUnknownGet(mlir_context_);
MlirModule module_op = mlirModuleCreateEmpty(loc);
MlirBlock block = mlirModuleGetBody(module_op);
mlirBlockAppendOwnedOperation(block, func_op);

// Apply passes to verify generated MLIR.
auto pass_manager = mlirPassManagerCreate(mlir_context_);
mlirPassManagerAddOwnedPass(
pass_manager,
mlirCreateVerifyBackendContractNoDecompositions()
);

MlirLogicalResult result = mlirPassManagerRun(
pass_manager,
module_op
);

if (mlirLogicalResultIsFailure(result)) {
throw std::runtime_error("MLIR verification has failed.");
}

return CreateComputation(module_op);
}

ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirOperation func_op) {
ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirModule module_op) {
return std::make_shared<TorchMlirComputation>(
func_op, mlir_context_, graph_, parameter_names_, input_output_aliases_);
module_op, mlir_context_, graph_, parameter_names_, input_output_aliases_);
}

torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
Expand Down Expand Up @@ -295,11 +327,11 @@ void TorchMlirLoweringContext::RegisterMlirDialects() {
///////////////////////////////////////////////////////////////////////////////

TorchMlirComputation::TorchMlirComputation(
MlirOperation func_op, MlirContext mlir_context,
MlirModule module_op, MlirContext mlir_context,
const std::shared_ptr<torch::jit::Graph>& graph,
std::unordered_map<int, std::string> parameters_map,
InputOutputAliases input_output_aliases)
: func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
: module_op_(std::move(module_op)), mlir_context_(std::move(mlir_context)),
graph_(graph), input_output_aliases_(input_output_aliases),
parameters_map_(parameters_map) {

Expand Down Expand Up @@ -340,7 +372,14 @@ std::shared_ptr<torch::jit::Graph> TorchMlirComputation::graph() const {
return graph_;
}

MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
MlirOperation TorchMlirComputation::func_op() const {
MlirBlock block = mlirModuleGetBody(module_op_);
return mlirBlockGetFirstOperation(block);
}

MlirModule TorchMlirComputation::module_op() const {
return module_op_;
}

MlirContext TorchMlirComputation::mlir_context() const {
return mlir_context_;
Expand Down Expand Up @@ -385,7 +424,7 @@ const std::string TorchMlirComputation::to_string() const {
*ss_ptr << std::string(part.data, part.length);
};
std::stringstream ss;
mlirOperationPrint(func_op_, print_callback, &ss);
mlirOperationPrint(mlirModuleGetOperation(module_op_), print_callback, &ss);
return ss.str();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
// embedded builder (returned by the builder() API).
torch::lazy::ComputationPtr Build() override;

virtual torch::lazy::ComputationPtr CreateComputation(MlirOperation func_op);
virtual torch::lazy::ComputationPtr CreateComputation(MlirModule module_op);

// Retrieves the lowered operation for an output. If the requested output is
// not available yet, the graph behind the output's Node is lowered, and the
Expand Down Expand Up @@ -123,7 +123,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias;

TorchMlirComputation(
MlirOperation func_op, MlirContext mlir_context,
MlirModule module_op, MlirContext mlir_context,
const std::shared_ptr<torch::jit::Graph>& graph,
std::unordered_map<int, std::string> parameters_map,
InputOutputAliases input_output_aliases);
Expand All @@ -142,6 +142,8 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {

MlirOperation func_op() const;

MlirModule module_op() const;

MlirContext mlir_context() const;

virtual const std::string debug_string() const;
Expand All @@ -155,7 +157,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
std::vector<Shape> parameter_shapes_;
Shape result_shape_;

MlirOperation func_op_;
MlirModule module_op_;
MlirContext mlir_context_;
std::shared_ptr<torch::jit::Graph> graph_;
InputOutputAliases input_output_aliases_;
Expand Down
45 changes: 45 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "jit_utils.h"

#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <ATen/core/type_factory.h>

namespace torch {
namespace jit {

void ConvertScalarImplicit(std::shared_ptr<Graph>& graph) {
DepthFirstGraphNodeIterator it(graph);
for (auto* node = it.next(); node != nullptr; node = it.next()) {
if (node->kind() != c10::aten::ScalarImplicit) {
continue;
}

auto input = node->input(0);
auto scalar_type = input->type()->cast<c10::TensorType>()->scalarType();
TORCH_CHECK(scalar_type, "scalar type is not defined for input value");

NodeKind node_type;
TypePtr output_type;
if (c10::isIntegralType(*scalar_type, false)) {
node_type = c10::aten::IntImplicit;
output_type = IntType::get();
} else if (c10::isFloatingType(*scalar_type)) {
node_type = c10::aten::FloatImplicit;
output_type = FloatType::get();
} else {
throw std::runtime_error(
"Expected isIntegralType or isFloatingType");
}

Value * output = graph
->create(node_type, {input})
->insertBefore(node)
->output()
->setType(output_type);
node->output()->replaceAllUsesWith(output);
node->destroy();
}
}

} // namespace jit
} // namespace torch
10 changes: 10 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Convert ScalarImplicit to IntImplicit or FloatImplicit.
TORCH_API void ConvertScalarImplicit(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
12 changes: 6 additions & 6 deletions test/Dialect/Torch/verify-backend-contract-error.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: torch-mlir-opt -torch-verify-backend-contract -split-input-file -verify-diagnostics %s
func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// expected-error @+2 {{found an op that was marked as backend illegal}}
// expected-note @+1 {{this is likely due to}}
%t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %t : !torch.vtensor<[?,?],f32>
// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s
func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
// expected-error @below {{unsupported by backend contract: tensor with unknown rank}}
// expected-note @below {{this is likely due to a missing transfer function}}
%t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor
return %t : !torch.vtensor
}