diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 2d57a341c3eb..553b814f4584 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -656,16 +656,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const { return false; } -// bufferization.to_buffer is not allowed to change the rank. -static void ensureToBufferOpIsValid(Value tensor, Type memrefType) { -#ifndef NDEBUG - auto rankedTensorType = llvm::dyn_cast(tensor.getType()); - assert((!rankedTensorType || llvm::cast(memrefType).getRank() == - rankedTensorType.getRank()) && - "to_buffer would be invalid: mismatching ranks"); -#endif -} - FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options) { #ifndef NDEBUG @@ -683,7 +673,7 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, FailureOr bufferType = getBufferType(value, options); if (failed(bufferType)) return failure(); - ensureToBufferOpIsValid(value, *bufferType); + return rewriter .create(value.getLoc(), *bufferType, value) .getResult(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index dc5c15877271..e0b057e781f1 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -73,9 +73,6 @@ struct BuiltinTensorExternalModel mlir::LogicalResult verifyCompatibleBufferType( mlir::Type tensor, BufferLikeType bufferType, llvm::function_ref emitError) const { - assert(isa(tensor) && "expected tensor type"); - assert(isa(bufferType) && "expected memref type"); - auto tensorType = cast(tensor); auto memrefType = cast(bufferType); diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir index 2c8807b66de7..9884b040119d 100644 --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() { // expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}} arith.constant {bufferization.manual_deallocation} 0 : index } + +// ----- + +func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3xf32> + return +} + +// ----- + +func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x4x3xf32> + return +} + +// ----- + +func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x4x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3x4xf16> + return +} + +// ----- + +func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %t2 = bufferization.to_tensor %b + : memref<1x2x3x4xf16> to tensor<1x2x3x4xf32> + return +} diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir index fc6df4a09f70..b0db1bb2d038 100644 --- a/mlir/test/Dialect/Bufferization/ops.mlir +++ b/mlir/test/Dialect/Bufferization/ops.mlir @@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>, bufferization.dealloc return %0#0, %0#1 : i1, i1 } + +// CHECK: func.func @test_builtin_custom_builtin_type_conversion +// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32> +func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>) + -> tensor<42xf32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to !test.test_memref<[42], f32> + %buffer = bufferization.to_buffer %t + : tensor<42xf32> to !test.test_memref<[42], f32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to tensor<42xf32> + %tensor = bufferization.to_tensor %buffer + : !test.test_memref<[42], f32> to tensor<42xf32> + + // CHECK: return %[[tensor]] + return %tensor : tensor<42xf32> +} + +// CHECK: func.func @test_custom_builtin_custom_type_conversion +// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>) +// CHECK-SAME: -> !test.test_tensor<[42], f32> +func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>) + -> !test.test_tensor<[42], f32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to memref<42xf32> + %buffer = bufferization.to_buffer %t + : !test.test_tensor<[42], f32> to memref<42xf32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to !test.test_tensor<[42], f32> + %tensor = bufferization.to_tensor %buffer + : memref<42xf32> to !test.test_tensor<[42], f32> + + // CHECK: return %[[tensor]] + return %tensor : !test.test_tensor<[42], f32> +} diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 496ff47a06e5..861217f24ddc 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -584,11 +584,17 @@ TestTensorType::getBufferType( ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType( ::mlir::bufferization::BufferLikeType bufferType, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { - auto testMemref = llvm::dyn_cast(bufferType); - if (!testMemref) - return emitError() << "expected TestMemrefType"; + if (auto testMemref = llvm::dyn_cast(bufferType)) { + const bool valid = getShape() == testMemref.getShape() && + getElementType() == testMemref.getElementType(); + return mlir::success(valid); + } + + if (auto builtinMemref = llvm::dyn_cast(bufferType)) { + const bool valid = getShape() == builtinMemref.getShape() && + getElementType() == builtinMemref.getElementType(); + return mlir::success(valid); + } - const bool valid = getShape() == testMemref.getShape() && - getElementType() == testMemref.getElementType(); - return mlir::success(valid); + return emitError() << "expected MemRefType or TestMemrefType"; }