From 4ec00fb3eafa885da6d305ebdf1361d4be54dedf Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 20 Feb 2022 05:49:33 +0900 Subject: [PATCH] [mlir][bufferize] Add a way for ops to fail the analysis Add `BufferizableOpInterface::verifyAnalysis`. Ops can implement this method to check for expected invariants and limitations. The purpose of this change is to introduce a modular way of checking assertions such as `assertScfForAliasingProperties`. Differential Revision: https://reviews.llvm.org/D120189 --- .../IR/BufferizableOpInterface.td | 17 +++++ .../Dialect/SCF/BufferizableOpInterfaceImpl.h | 10 --- .../Transforms/OneShotAnalysis.cpp | 13 ++++ .../Transforms/ComprehensiveBufferizePass.cpp | 3 - .../BufferizableOpInterfaceImpl.cpp | 66 +++++++++---------- ...omprehensive-module-bufferize-invalid.mlir | 6 +- ...omprehensive-module-bufferize-partial.mlir | 29 -------- .../Linalg/TestComprehensiveBufferize.cpp | 4 -- 8 files changed, 65 insertions(+), 83 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td index f6c51dae92eaf..ac26d327d2e31 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -290,6 +290,23 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { /*defaultImplementation=*/[{ return false; }] + >, + InterfaceMethod< + /*desc=*/[{ + Return `failure` if this op does not pass the analysis. This method + is run during One-Shot Bufferize (after all post-analysis steps). If + the op does not pass the analysis, bufferization is aborted. + + This method can be used to check expected invariants and limitations + of the current bufferization implementation. + }], + /*retType=*/"LogicalResult", + /*methodName=*/"verifyAnalysis", + /*args=*/(ins "const BufferizationState &":$state), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return success(); + }] > ]; diff --git a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h index dfeb9514409fb..08c6ca2ee0d29 100644 --- a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h @@ -20,16 +20,6 @@ class BufferizationAliasInfo; } // namespace bufferization namespace scf { -/// Assert that yielded values of an scf.for op are aliasing their corresponding -/// bbArgs. This is required because the i-th OpResult of an scf.for op is -/// currently assumed to alias with the i-th iter_arg (in the absence of -/// conflicts). -LogicalResult -assertScfForAliasingProperties(Operation *op, - bufferization::BufferizationState &state, - bufferization::BufferizationAliasInfo &aliasInfo, - SmallVector &newOps); - void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); } // namespace scf } // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index 78e3ac8aba7c3..6232e9ae7cba0 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -778,6 +778,19 @@ LogicalResult bufferization::analyzeOp(Operation *op, return failure(); } + // Analysis verification: After setting up alias/equivalence sets, each op + // can check for expected invariants/limitations and fail the analysis if + // necessary. + bool passedAnalysis = true; + op->walk([&](Operation *op) { + if (BufferizableOpInterface bufferizableOp = + options.dynCastBufferizableOp(op)) + if (failed(bufferizableOp.verifyAnalysis(state))) + passedAnalysis = false; + }); + if (!passedAnalysis) + return failure(); + // Annotate operations if we only want to report the analysis. if (options.testAnalysisOnly) annotateOpsWithBufferizationMarkers(op, aliasInfo, state); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp index cd71264064168..b4ac512463cb5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -105,9 +105,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() { opt = *options; } - // Only certain scf.for ops are supported by the analysis. - opt.addPostAnalysisStep(scf::assertScfForAliasingProperties); - ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 83a70e8dcf3af..d4dd3489841c6 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -385,6 +385,37 @@ struct ForOpInterface return success(); } + + /// Assert that yielded values of an scf.for op are aliasing with their + /// corresponding bbArgs. This is required because the i-th OpResult of an + /// scf.for op is currently assumed to alias with the i-th iter_arg (in the + /// absence of conflicts). + LogicalResult verifyAnalysis(Operation *op, + const BufferizationState &state) const { + auto forOp = cast(op); + auto yieldOp = + cast(forOp.getLoopBody().front().getTerminator()); + for (OpOperand &operand : yieldOp->getOpOperands()) { + auto tensorType = operand.get().getType().dyn_cast(); + if (!tensorType) + continue; + + OpOperand &forOperand = forOp.getOpOperandForResult( + forOp->getResult(operand.getOperandNumber())); + auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); + // Note: This is overly strict. We should check for aliasing bufferized + // values. But we don't have a "must-alias" analysis yet. + if (!state.areEquivalentBufferizedValues(operand.get(), bbArg)) + // TODO: this could get resolved with copies but it can also turn into + // swaps so we need to be careful about order of copies. + return yieldOp->emitError() + << "Yield operand #" << operand.getOperandNumber() + << " does not bufferize to a buffer that is aliasing the " + "matching" + << " enclosing scf::for operand"; + } + return success(); + } }; /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so @@ -434,41 +465,6 @@ struct YieldOpInterface } // namespace scf } // namespace mlir -LogicalResult mlir::scf::assertScfForAliasingProperties( - Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) { - LogicalResult status = success(); - - op->walk([&](scf::ForOp forOp) { - auto yieldOp = - cast(forOp.getLoopBody().front().getTerminator()); - for (OpOperand &operand : yieldOp->getOpOperands()) { - auto tensorType = operand.get().getType().dyn_cast(); - if (!tensorType) - continue; - - OpOperand &forOperand = forOp.getOpOperandForResult( - forOp->getResult(operand.getOperandNumber())); - auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - // Note: This is overly strict. We should check for aliasing bufferized - // values. But we don't have a "must-alias" analysis yet. - if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) { - // TODO: this could get resolved with copies but it can also turn into - // swaps so we need to be careful about order of copies. - status = - yieldOp->emitError() - << "Yield operand #" << operand.getOperandNumber() - << " does not bufferize to a buffer that is aliasing the matching" - << " enclosing scf::for operand"; - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); - - return status; -} - void mlir::scf::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addOpInterface(); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir index d1791da1646bf..2adf2aadc2d93 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -87,7 +87,7 @@ func @scf_for(%A : tensor, %B : tensor {linalg.inplaceable = true}, %C : tensor<4xf32>, %lb : index, %ub : index, %step : index) - -> (tensor, tensor) + -> (f32, f32) { %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B) -> (tensor, tensor) @@ -102,7 +102,9 @@ func @scf_for(%A : tensor, scf.yield %ttB, %ttA : tensor, tensor } - return %r0#0, %r0#1: tensor, tensor + %f0 = tensor.extract %r0#0[%step] : tensor + %f1 = tensor.extract %r0#1[%step] : tensor + return %f0, %f1: f32, f32 } // ----- diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir index ea1251fc080b2..0ea8b59adb9ef 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir @@ -179,35 +179,6 @@ func @simple_tensor_test(%t1 : tensor, %f : f32) -> tensor { // ----- -// CHECK-SCF-LABEL: func @simple_scf_for( -// CHECK-SCF-SAME: %[[t1:.*]]: tensor -func @simple_scf_for( - %t1: tensor, %sz: index, %step: index, %f: f32) -> tensor { - %c0 = arith.constant 0 : index - - // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]] - // CHECK-SCF: %[[alloc:.*]] = memref.alloc - // CHECK-SCF: %[[casted:.*]] = memref.cast %[[alloc]] - // CHECK-SCF: memref.copy %[[t1_memref]], %[[alloc]] - // CHECK-SCF: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) { - %0 = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t1) -> tensor { - // CHECK-SCF: %[[arg0_tensor:.*]] = bufferization.to_tensor %[[arg0]] - // CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[arg0_tensor]] - %1 = tensor.insert %f into %arg0[%iv] : tensor - - // CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]] - // CHECK-SCF: scf.yield %[[insert_memref]] - scf.yield %1 : tensor - } - // CHECK-SCF: } - - // CHECK-SCF: %[[scf_for_tensor:.*]] = bufferization.to_tensor %[[scf_for]] - // CHECK-SCF: return %[[scf_for_tensor]] - return %0 : tensor -} - -// ----- - // CHECK-SCF-LABEL: func @simple_scf_if( // CHECK-SCF-SAME: %[[t1:.*]]: tensor {linalg.inplaceable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index func @simple_scf_if(%t1: tensor {linalg.inplaceable = true}, %c: i1, %pos: index, %f: f32) diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp index f0b6b0e669ec4..9eb68343eaadf 100644 --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -102,10 +102,6 @@ struct TestComprehensiveFunctionBufferize void TestComprehensiveFunctionBufferize::runOnOperation() { auto options = std::make_unique(); - - if (!allowReturnMemref) - options->addPostAnalysisStep(scf::assertScfForAliasingProperties); - options->allowReturnMemref = allowReturnMemref; options->allowUnknownOps = allowUnknownOps; options->testAnalysisOnly = testAnalysisOnly;