[mlir][bufferization] Bufferize with PostOrder traversal
Bufferizing with a post-order traversal is useful because the result type of an op can sometimes be inferred from its body (e.g., `scf.if`). This will be utilized in subsequent changes.
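
For illustration, a minimal sketch of the traversal change in terms of MLIR's `Operation::walk` API (callback bodies elided; the actual worklist logic is in `Bufferize.cpp` below):

```cpp
#include "mlir/IR/Operation.h"

using namespace mlir;

void collectOpsForBufferization(Operation *root) {
  // Old behavior: pre-order. A parent op is visited before the ops in its
  // regions, so its bufferized result types must be chosen before its body
  // has been bufferized.
  root->walk<WalkOrder::PreOrder>([](Operation *op) { /* enqueue op */ });

  // New behavior: post-order. Ops inside a region are visited before their
  // parent, so an op such as scf.if can infer its result types from its
  // already-bufferized body.
  root->walk<WalkOrder::PostOrder>([](Operation *op) { /* enqueue op */ });
}
```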

Also introduces a new `getBufferType` interface method on BufferizableOpInterface. This method is useful for computing a bufferized block argument type with respect to OpOperand types of the parent op.
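
As a hedged sketch of how an op might use this hook (the op `MyLoopOp` and its one-to-one init-operand-to-bbArg mapping are hypothetical, not part of this change), an implementation could derive a block argument's buffer type from the operand it carries:

```cpp
// Hypothetical BufferizableOpInterface external-model override; MyLoopOp
// and its operand/bbArg mapping are assumptions for illustration only.
BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
                             const BufferizationOptions &options) const {
  auto loopOp = cast<MyLoopOp>(op);
  // Reuse the buffer type of the init operand carried by this bbArg
  // instead of falling back to a fully dynamic layout map.
  Value initValue = loopOp->getOperand(bbArg.getArgNumber());
  return bufferization::getBufferType(initValue, options);
}
```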

Differential Revision: https://reviews.llvm.org/D128420
matthias-springer committed Jun 27, 2022
1 parent 5830da1 commit ba9d886
Showing 6 changed files with 42 additions and 11 deletions.
@@ -337,7 +337,24 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*defaultImplementation=*/[{
return success();
}]
-    >
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the bufferized type of the given tensor block argument. The
+        block argument is guaranteed to belong to a block of this op.
+      }],
+      /*retType=*/"BaseMemRefType",
+      /*methodName=*/"getBufferType",
+      /*args=*/(ins "BlockArgument":$bbArg,
+                    "const BufferizationOptions &":$options),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(bbArg.getOwner()->getParentOp() == $_op &&
+               "bbArg must belong to this op");
+        auto tensorType = bbArg.getType().cast<TensorType>();
+        return bufferization::getMemRefType(tensorType, options);
+      }]
+    >,
];

let extraClassDeclaration = [{
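For context, a minimal sketch of how a caller dispatches to the new method, mirroring the `bufferization::getBufferType` change in the next file (the helper name is illustrative):

```cpp
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

// Resolve the buffer type of a tensor block argument: ask the parent op's
// BufferizableOpInterface implementation if there is one, otherwise fall
// back to the default memref type for the tensor.
static BaseMemRefType
resolveBbArgBufferType(BlockArgument bbArg, const BufferizationOptions &options) {
  Operation *parentOp = bbArg.getOwner()->getParentOp();
  if (auto bufferizableOp = options.dynCastBufferizableOp(parentOp))
    return bufferizableOp.getBufferType(bbArg, options);
  return getMemRefType(bbArg.getType().cast<TensorType>(), options);
}
```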
@@ -482,8 +482,10 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {

Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options) {
+#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
+#endif // NDEBUG

// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
@@ -492,7 +494,7 @@ Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
-  Type memrefType = getMemRefType(tensorType, options);
+  Type memrefType = getBufferType(value, options);
ensureToMemrefOpIsValid(value, memrefType);
return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
value);
@@ -507,6 +509,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

+  if (auto bbArg = value.dyn_cast<BlockArgument>())
+    if (auto bufferizableOp =
+            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
+      return bufferizableOp.getBufferType(bbArg, options);
+
return getMemRefType(tensorType, options);
}

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (9 additions & 2 deletions)
@@ -393,9 +393,16 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
// Otherwise, we have to use a memref type with a fully dynamic layout map to
// avoid copies. We are currently missing patterns for layout maps to
// canonicalize away (or canonicalize to more precise layouts).
+  //
+  // FuncOps must be bufferized before their bodies, so add them to the worklist
+  // first.
  SmallVector<Operation *> worklist;
-  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
-    if (hasTensorSemantics(op))
+  op->walk([&](func::FuncOp funcOp) {
+    if (hasTensorSemantics(funcOp))
+      worklist.push_back(funcOp);
+  });
+  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
+    if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
worklist.push_back(op);
});

@@ -725,7 +725,7 @@ struct WhileOpInterface
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-          return getBufferType(bbArg, options).cast<Type>();
+          return bufferization::getBufferType(bbArg, options).cast<Type>();
}));

// Construct a new scf.while op with memref instead of tensor values.
@@ -1107,7 +1107,7 @@ struct ParallelInsertSliceOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &b,
const BufferizationOptions &options) const {
// Will be bufferized as part of ForeachThreadOp.
-    return failure();
+    return success();
}

// TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
@@ -154,7 +154,7 @@ struct AssumingYieldOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
// Op is bufferized as part of AssumingOp.
-    return failure();
+    return success();
}
};

mlir/test/Dialect/SCF/one-shot-bufferize.mlir (4 additions & 4 deletions)
@@ -313,10 +313,10 @@ func.func @scf_for_swapping_yields(
// CHECK: %[[alloc2:.*]] = memref.alloc(%{{.*}})
// CHECK: memref.copy %[[iter2]], %[[alloc2]]
// CHECK: memref.dealloc %[[iter2]]
-// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]]
// CHECK: %[[alloc1:.*]] = memref.alloc(%{{.*}})
// CHECK: memref.copy %[[iter1]], %[[alloc1]]
// CHECK: memref.dealloc %[[iter1]]
+// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]]
// CHECK: %[[casted1:.*]] = memref.cast %[[alloc1]]
// CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
// CHECK: memref.dealloc %[[alloc1]]
@@ -384,10 +384,10 @@ func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
// CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w1]], %[[a1]]
// CHECK: memref.dealloc %[[w1]]
-// CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w0]], %[[a0]]
// CHECK: memref.dealloc %[[w0]]
+// CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
// CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
// CHECK: memref.dealloc %[[a0]]
@@ -437,10 +437,10 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
// CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w1]], %[[a1]]
// CHECK: memref.dealloc %[[w1]]
-// CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w0]], %[[a0]]
// CHECK: memref.dealloc %[[w0]]
+// CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
// CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
// CHECK: memref.dealloc %[[a0]]
@@ -457,9 +457,9 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
// CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[b1]], %[[a3]]
// CHECK: memref.dealloc %[[b1]]
-// CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
// CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[b0]], %[[a2]]
+// CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
// CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
// CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
// CHECK: memref.dealloc %[[a2]]
