Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/torch-mlir/Dialect/Torch/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ std::unique_ptr<OperationPass<ModuleOp>>
createLowerToBackendContractPass(int maxIterations, bool decompose,
ArrayRef<std::string> backendLegalOps);

std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyBackendContractPass(bool decompose,
ArrayRef<std::string> backendLegalOps);

StringRef getAbstractInterpLibrary();

Expand Down
12 changes: 10 additions & 2 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,21 @@ def LowerToBackendContract
def VerifyBackendContract
: Pass<"torch-verify-backend-contract", "ModuleOp"> {
let summary = "Check that program satisfies backend contract.";
let constructor =
"mlir::torch::Torch::createVerifyBackendContractPass()";
let constructor = [{
mlir::torch::Torch::createVerifyBackendContractPass(
/*decompose=*/true, /*backendLegalOps=*/{})
}];
let description = [{
This pass performs a set of inspections to check that program satisfies backend
contract. In case of check failure it prints out the error message and returns
`signalPassFailure()` status.
}];
let options = [
Option<"decompose", "decompose", "bool", /*default=*/"true",
"Decompose ops.">,
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
"List of ops to be considered legal for the backend.">
];
}

#endif // TORCHMLIR_TORCH_PASSES
33 changes: 26 additions & 7 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,17 @@ static bool satisfiesBackendContract(ModuleOp module,
return true;
}

// Explicitly set ops and dialects allowed and not allowed in backend contract.
static ConversionTarget
getBackendContractTarget(MLIRContext *context, bool decompose,
ArrayRef<std::string> backendLegalOps) {
ConversionTarget target(*context);
target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
if (decompose)
markDecomposedOpsAsIllegal(context, target, backendLegalOps);
return target;
}

namespace {
class LowerToBackendContractPass
: public LowerToBackendContractBase<LowerToBackendContractPass> {
Expand All @@ -239,10 +250,8 @@ class LowerToBackendContractPass
void runOnOperation() override {
ModuleOp module = getOperation();
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
if (decompose)
markDecomposedOpsAsIllegal(context, target, backendLegalOps);
ConversionTarget target =
getBackendContractTarget(context, decompose, backendLegalOps);

OpPassManager pm(module.getOperationName());
TorchLoweringPipelineOptions options;
Expand Down Expand Up @@ -279,9 +288,17 @@ class LowerToBackendContractPass
class VerifyBackendContractPass
: public VerifyBackendContractBase<VerifyBackendContractPass> {
public:
VerifyBackendContractPass() = default;
VerifyBackendContractPass(bool decompose,
ArrayRef<std::string> backendLegalOps) {
this->decompose = decompose;
this->backendLegalOps = backendLegalOps;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
ConversionTarget target =
getBackendContractTarget(context, decompose, backendLegalOps);

if (!satisfiesBackendContract(getOperation(), target,
/*actuallyEmitDiagnostics=*/true)) {
return signalPassFailure();
Expand All @@ -298,8 +315,10 @@ mlir::torch::Torch::createLowerToBackendContractPass(
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyBackendContractPass() {
return std::make_unique<VerifyBackendContractPass>();
mlir::torch::Torch::createVerifyBackendContractPass(
bool decompose, ArrayRef<std::string> backendLegalOps) {
return std::make_unique<VerifyBackendContractPass>(decompose,
backendLegalOps);
}

// The backend contract guarantees that ops with decompositions available will
Expand Down
7 changes: 7 additions & 0 deletions test/Dialect/Torch/verify-backend-contract-error.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// RUN: torch-mlir-opt -torch-verify-backend-contract -split-input-file -verify-diagnostics %s
func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// expected-error @+2 {{found an op that was marked as backend illegal}}
// expected-note @+1 {{this is likely due to}}
%t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %t : !torch.vtensor<[?,?],f32>
}