diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d29150a7403f9..6c5719ce6df8e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -378,10 +378,18 @@ static LogicalResult getFuncOpsOrderedByCalls(
 
 /// Helper function that extracts the source from a memref.cast. If the given
 /// value is not a memref.cast result, simply returns the given value.
+/// Only unpacks casts where the source is at least as specific as the result
+/// (i.e., does not unpack casts from unranked to ranked memref, which would
+/// downgrade the type).
 static Value unpackCast(Value v) {
   auto castOp = v.getDefiningOp<memref::CastOp>();
   if (!castOp)
     return v;
+  // Do not unpack a cast from unranked to ranked memref: folding would
+  // downgrade the function return type from ranked to unranked.
+  if (isa<UnrankedMemRefType>(castOp.getSource().getType()) &&
+      isa<MemRefType>(v.getType()))
+    return v;
   return castOp.getSource();
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 8db1ebb87a1e5..d5cb7a0f14f5a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -884,3 +884,24 @@ func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
   // CHECK: return %[[out]]
   return %out : !test.test_tensor<[4, 8], f64>
 }
+
+// -----
+
+// Test that foldMemRefCasts does not downgrade a ranked return type to unranked
+// when the return value is produced by a memref.cast from unranked to ranked.
+// CHECK-LABEL: func.func @ranked_return_via_unranked_call(
+// CHECK-SAME: %[[arg:.*]]: memref<64x20x40xf32
+// CHECK-SAME: ) -> memref<64x20x40xf32
+func.func @ranked_return_via_unranked_call(%arg0: tensor<64x20x40xf32>) -> tensor<64x20x40xf32> {
+  // CHECK: %[[cast:.*]] = memref.cast %[[arg]]
+  // CHECK-SAME: to memref<*xf32>
+  %u = tensor.cast %arg0 : tensor<64x20x40xf32> to tensor<*xf32>
+  // CHECK: %[[call:.*]] = call @relu_unranked(%[[cast]])
+  %r = call @relu_unranked(%u) : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: %[[cast2:.*]] = memref.cast %[[call]]
+  // CHECK-SAME: to memref<64x20x40xf32
+  %b = tensor.cast %r : tensor<*xf32> to tensor<64x20x40xf32>
+  // CHECK: return %[[cast2]]
+  return %b : tensor<64x20x40xf32>
+}
+func.func private @relu_unranked(tensor<*xf32>) -> tensor<*xf32>