12 changes: 1 addition & 11 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
return false;
}

// bufferization.to_buffer is not allowed to change the rank.
static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
rankedTensorType.getRank()) &&
"to_buffer would be invalid: mismatching ranks");
#endif
}
@andrey-golubev (Contributor Author), Nov 12, 2025:
Note: this actually blocked our downstream from switching to the mlir::bufferization::getBuffer() implementation. From what I can tell, this check was superseded a long time ago by ToBufferOp's verifier (which checks, for builtins, both the shape, and thus the rank, and the element type).

Unfortunately, testing this specifically is a pain: I would need a fairly large setup to showcase how this fails. I hope it is fine that I just delete it and provide some general tests that exercise the op verifiers instead.
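For reference, a minimal sketch (hypothetical, not part of this diff) of how a downstream pass might resolve a tensor value through mlir::bufferization::getBuffer(); only the getBuffer() signature is taken from the hunk below, while the helper name and includes are illustrative:

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical downstream helper: resolves a tensor SSA value to its buffer.
// getBuffer() computes the buffer type and materializes a
// bufferization.to_buffer cast; any tensor/buffer mismatch is now diagnosed
// by that op's verifier rather than by the removed assertion.
static FailureOr<Value>
resolveTensorToBuffer(RewriterBase &rewriter, Value tensor,
                      const bufferization::BufferizationOptions &options,
                      const bufferization::BufferizationState &state) {
  FailureOr<Value> buffer =
      bufferization::getBuffer(rewriter, tensor, options, state);
  if (failed(buffer))
    return failure();
  return *buffer;
}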


FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
const BufferizationState &state) {
@@ -708,7 +698,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
if (failed(bufferType))
return failure();
ensureToBufferOpIsValid(value, *bufferType);

return bufferization::ToBufferOp::create(rewriter, value.getLoc(),
*bufferType, value)
.getResult();
3 changes: 0 additions & 3 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -54,9 +54,6 @@ struct BuiltinTensorExternalModel
mlir::LogicalResult verifyCompatibleBufferType(
mlir::Type tensor, BufferLikeType bufferType,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
assert(isa<TensorType>(tensor) && "expected tensor type");
assert(isa<BaseMemRefType>(bufferType) && "expected memref type");

auto tensorType = cast<ShapedType>(tensor);
auto memrefType = cast<ShapedType>(bufferType);

60 changes: 60 additions & 0 deletions mlir/test/Dialect/Bufferization/invalid.mlir
@@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() {
// expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}}
arith.constant {bufferization.manual_deallocation} 0 : index
}

// -----
Contributor Author:
@matthias-springer I have extended these tests with rank / shape / element type checks for builtin types. Now, this does "confirm" that the ensureToBufferOpIsValid() call removed from ::getBuffer() is indeed unnecessary.


func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) {
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{shapes do not match}}
%b = bufferization.to_buffer %t
: tensor<1x2x3x4xf32> to memref<1x2x3xf32>
return
}

// -----

func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) {
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{shapes do not match}}
%t = bufferization.to_tensor %b
: memref<1x2x3xf32> to tensor<1x2x3x4xf32>
return
}

// -----

func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) {
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{shapes do not match}}
%b = bufferization.to_buffer %t
: tensor<1x2x3x4xf32> to memref<1x2x4x3xf32>
return
}

// -----

func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) {
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{shapes do not match}}
%t = bufferization.to_tensor %b
: memref<1x2x4x3xf32> to tensor<1x2x3x4xf32>
return
}

// -----

func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) {
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{element types do not match}}
%b = bufferization.to_buffer %t
: tensor<1x2x3x4xf32> to memref<1x2x3x4xf16>
return
}

// -----

func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) {
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{element types do not match}}
%t2 = bufferization.to_tensor %b
: memref<1x2x3x4xf16> to tensor<1x2x3x4xf32>
return
}
37 changes: 37 additions & 0 deletions mlir/test/Dialect/Bufferization/ops.mlir
@@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
bufferization.dealloc
return %0#0, %0#1 : i1, i1
}

@andrey-golubev (Contributor Author), Nov 18, 2025:
Note: moved the tests that were originally in tensorlike-bufferlike.mlir to ops.mlir. This seems to be a better place since the tests essentially check what the operations can do.

// CHECK: func.func @test_builtin_custom_builtin_type_conversion
// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32>
func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>)
-> tensor<42xf32> {
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
// CHECK-SAME: to !test.test_memref<[42], f32>
%buffer = bufferization.to_buffer %t
: tensor<42xf32> to !test.test_memref<[42], f32>

// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
// CHECK-SAME: to tensor<42xf32>
%tensor = bufferization.to_tensor %buffer
: !test.test_memref<[42], f32> to tensor<42xf32>

// CHECK: return %[[tensor]]
return %tensor : tensor<42xf32>
}

// CHECK: func.func @test_custom_builtin_custom_type_conversion
// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>)
// CHECK-SAME: -> !test.test_tensor<[42], f32>
func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>)
-> !test.test_tensor<[42], f32> {
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
// CHECK-SAME: to memref<42xf32>
%buffer = bufferization.to_buffer %t
: !test.test_tensor<[42], f32> to memref<42xf32>

// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
// CHECK-SAME: to !test.test_tensor<[42], f32>
%tensor = bufferization.to_tensor %buffer
: memref<42xf32> to !test.test_tensor<[42], f32>

// CHECK: return %[[tensor]]
return %tensor : !test.test_tensor<[42], f32>
}
18 changes: 12 additions & 6 deletions mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -569,11 +569,17 @@ TestTensorType::getBufferType(
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
::mlir::bufferization::BufferLikeType bufferType,
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
auto testMemref = dyn_cast<TestMemrefType>(bufferType);
if (!testMemref)
return emitError() << "expected TestMemrefType";
if (auto testMemref = dyn_cast<TestMemrefType>(bufferType)) {
const bool valid = getShape() == testMemref.getShape() &&
getElementType() == testMemref.getElementType();
return mlir::success(valid);
}

if (auto builtinMemref = dyn_cast<MemRefType>(bufferType)) {
const bool valid = getShape() == builtinMemref.getShape() &&
getElementType() == builtinMemref.getElementType();
return mlir::success(valid);
}

const bool valid = getShape() == testMemref.getShape() &&
getElementType() == testMemref.getElementType();
return mlir::success(valid);
return emitError() << "expected MemRefType or TestMemrefType";
Member:
Did you mean to trigger this in the test above?

@andrey-golubev (Contributor Author), Nov 18, 2025:
This specifically: not really.
The bulk of the change is about allowing a mix of builtin memref and test memref (which is what the test above exercises).
The emit-error message is changed just to better align with the new logic (so that the error, if it happens, is less ambiguous). I don't really intend to test code that is written for the sake of testing :)

}