Skip to content

Commit c1067fd

Browse files
[mlir][bufferization] Refine tensor-buffer compatibility checks
Generally, to_tensor and to_buffer already perform sufficient verification. However, there are some unnecessary strict constraints: * builtin tensor requires its buffer counterpart to always be memref * to_buffer on ranked tensor requires to always return memref These checks are assertions (i.e. preconditions), however, they actually prevent an apparently useful bufferization where builtin tensors could become custom buffers. Lift these assertions, maintaining the verification procedure unchanged, to allow builtin -> custom bufferizations at operation boundary level.
1 parent 68a4af6 commit c1067fd

File tree

4 files changed

+40
-21
lines changed

4 files changed

+40
-21
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
680680
return false;
681681
}
682682

683-
// bufferization.to_buffer is not allowed to change the rank.
684-
static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
685-
#ifndef NDEBUG
686-
auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
687-
assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
688-
rankedTensorType.getRank()) &&
689-
"to_buffer would be invalid: mismatching ranks");
690-
#endif
691-
}
692-
693683
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
694684
const BufferizationOptions &options,
695685
const BufferizationState &state) {
@@ -708,7 +698,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
708698
FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
709699
if (failed(bufferType))
710700
return failure();
711-
ensureToBufferOpIsValid(value, *bufferType);
701+
712702
return bufferization::ToBufferOp::create(rewriter, value.getLoc(),
713703
*bufferType, value)
714704
.getResult();

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ struct BuiltinTensorExternalModel
5454
mlir::LogicalResult verifyCompatibleBufferType(
5555
mlir::Type tensor, BufferLikeType bufferType,
5656
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
57-
assert(isa<TensorType>(tensor) && "expected tensor type");
58-
assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
59-
6057
auto tensorType = cast<ShapedType>(tensor);
6158
auto memrefType = cast<ShapedType>(bufferType);
6259

mlir/test/Dialect/Bufferization/Transforms/tensorlike-bufferlike.mlir

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-tensorlike-bufferlike -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-tensorlike-bufferlike -verify-diagnostics -split-input-file | FileCheck %s
22

33
// CHECK: func.func @builtin_unranked
44
// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_buffer_like"}}
@@ -35,3 +35,29 @@ func.func @custom_memref(%t: !test.test_memref<[42], f32>) -> ()
3535
{
3636
return
3737
}
38+
39+
// -----
40+
41+
// CHECK: func.func @builtin_custom_builtin_roundtrip
42+
// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_tensor_like"}}
43+
func.func @builtin_custom_builtin_roundtrip(%t: tensor<42xf32>)
44+
-> tensor<42xf32> {
45+
%buffer = bufferization.to_buffer %t
46+
: tensor<42xf32> to !test.test_memref<[42], f32>
47+
%tensor = bufferization.to_tensor %buffer
48+
: !test.test_memref<[42], f32> to tensor<42xf32>
49+
return %tensor : tensor<42xf32>
50+
}
51+
52+
// -----
53+
54+
// CHECK: func.func @custom_builtin_custom_roundtrip
55+
// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_tensor_like"}}
56+
func.func @custom_builtin_custom_roundtrip(%t: !test.test_tensor<[42], f32>)
57+
-> !test.test_tensor<[42], f32> {
58+
%buffer = bufferization.to_buffer %t
59+
: !test.test_tensor<[42], f32> to memref<42xf32>
60+
%tensor = bufferization.to_tensor %buffer
61+
: memref<42xf32> to !test.test_tensor<[42], f32>
62+
return %tensor : !test.test_tensor<[42], f32>
63+
}

mlir/test/lib/Dialect/Test/TestTypes.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -569,11 +569,17 @@ TestTensorType::getBufferType(
569569
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
570570
::mlir::bufferization::BufferLikeType bufferType,
571571
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
572-
auto testMemref = dyn_cast<TestMemrefType>(bufferType);
573-
if (!testMemref)
574-
return emitError() << "expected TestMemrefType";
572+
if (auto testMemref = dyn_cast<TestMemrefType>(bufferType)) {
573+
const bool valid = getShape() == testMemref.getShape() &&
574+
getElementType() == testMemref.getElementType();
575+
return mlir::success(valid);
576+
}
577+
578+
if (auto builtinMemref = dyn_cast<MemRefType>(bufferType)) {
579+
const bool valid = getShape() == builtinMemref.getShape() &&
580+
getElementType() == builtinMemref.getElementType();
581+
return mlir::success(valid);
582+
}
575583

576-
const bool valid = getShape() == testMemref.getShape() &&
577-
getElementType() == testMemref.getElementType();
578-
return mlir::success(valid);
584+
return emitError() << "expected TestMemrefType";
579585
}

0 commit comments

Comments
 (0)