diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index e0cf353da207f..9b11270e7bbe2 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const { return false; } -// bufferization.to_buffer is not allowed to change the rank. -static void ensureToBufferOpIsValid(Value tensor, Type memrefType) { -#ifndef NDEBUG - auto rankedTensorType = llvm::dyn_cast(tensor.getType()); - assert((!rankedTensorType || llvm::cast(memrefType).getRank() == - rankedTensorType.getRank()) && - "to_buffer would be invalid: mismatching ranks"); -#endif -} - FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state) { @@ -708,7 +698,7 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, FailureOr bufferType = getBufferType(value, options, state); if (failed(bufferType)) return failure(); - ensureToBufferOpIsValid(value, *bufferType); + return bufferization::ToBufferOp::create(rewriter, value.getLoc(), *bufferType, value) .getResult(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index d6c3cd62ee742..bd177ba1afccd 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -54,9 +54,6 @@ struct BuiltinTensorExternalModel mlir::LogicalResult verifyCompatibleBufferType( mlir::Type tensor, BufferLikeType bufferType, llvm::function_ref emitError) const { - assert(isa(tensor) && "expected tensor type"); - assert(isa(bufferType) && "expected memref type"); - auto tensorType = cast(tensor); auto memrefType = cast(bufferType); diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir index 2c8807b66de74..9884b040119d0 100644 --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() { // expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}} arith.constant {bufferization.manual_deallocation} 0 : index } + +// ----- + +func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3xf32> + return +} + +// ----- + +func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x4x3xf32> + return +} + +// ----- + +func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x4x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3x4xf16> + return +} + +// ----- + +func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %t2 = bufferization.to_tensor %b + : memref<1x2x3x4xf16> to tensor<1x2x3x4xf32> + return +} diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir index fc6df4a09f706..b0db1bb2d0389 100644 --- a/mlir/test/Dialect/Bufferization/ops.mlir +++ b/mlir/test/Dialect/Bufferization/ops.mlir @@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>, bufferization.dealloc return %0#0, %0#1 : i1, i1 } + +// CHECK: func.func @test_builtin_custom_builtin_type_conversion +// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32> +func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>) + -> tensor<42xf32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to !test.test_memref<[42], f32> + %buffer = bufferization.to_buffer %t + : tensor<42xf32> to !test.test_memref<[42], f32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to tensor<42xf32> + %tensor = bufferization.to_tensor %buffer + : !test.test_memref<[42], f32> to tensor<42xf32> + + // CHECK: return %[[tensor]] + return %tensor : tensor<42xf32> +} + +// CHECK: func.func @test_custom_builtin_custom_type_conversion +// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>) +// CHECK-SAME: -> !test.test_tensor<[42], f32> +func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>) + -> !test.test_tensor<[42], f32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to memref<42xf32> + %buffer = bufferization.to_buffer %t + : !test.test_tensor<[42], f32> to memref<42xf32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to !test.test_tensor<[42], f32> + %tensor = bufferization.to_tensor %buffer + : memref<42xf32> to !test.test_tensor<[42], f32> + + // CHECK: return %[[tensor]] + return %tensor : !test.test_tensor<[42], f32> +} diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 614121f1d43dd..9cf64a896d28a 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -569,11 +569,17 @@ TestTensorType::getBufferType( ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType( ::mlir::bufferization::BufferLikeType bufferType, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { - auto testMemref = dyn_cast(bufferType); - if (!testMemref) - return emitError() << "expected TestMemrefType"; + if (auto testMemref = dyn_cast(bufferType)) { + const bool valid = getShape() == testMemref.getShape() && + getElementType() == testMemref.getElementType(); + return mlir::success(valid); + } + + if (auto builtinMemref = dyn_cast(bufferType)) { + const bool valid = getShape() == builtinMemref.getShape() && + getElementType() == builtinMemref.getElementType(); + return mlir::success(valid); + } - const bool valid = getShape() == testMemref.getShape() && - getElementType() == testMemref.getElementType(); - return mlir::success(valid); + return emitError() << "expected MemRefType or TestMemrefType"; }