diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 404b2aacf1450..f31811ad7b98e 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2291,6 +2291,21 @@ struct ReinterpretCastOpConstantFolder [](OpFoldResult ofr) { return isa(ofr); })) return failure(); + // Do not fold if the offset is a negative constant; ViewLikeInterface + // verifies that static offsets are non-negative. + if (auto cst = getConstantIntValue(offsets[0])) + if (*cst < 0) + return rewriter.notifyMatchFailure( + op, "negative constant offset is invalid"); + + // Do not fold if any size is a negative constant; MemRefType::get asserts + // non-negative static sizes. + for (OpFoldResult sizeOfr : sizes) + if (auto cst = getConstantIntValue(sizeOfr)) + if (*cst < 0) + return rewriter.notifyMatchFailure( + op, "negative constant size is invalid"); + auto newReinterpretCast = ReinterpretCastOp::create( rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 92754dc919695..fb1e7d00feb47 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -1287,6 +1287,69 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me // ----- +// Check that reinterpret_cast with a negative constant size is not folded. +// Folding would attempt to create a MemRefType with a negative static dimension, +// which triggers an assertion in MemRefType::get (issue #188407). +// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size +// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[SZ:.*]] = arith.constant -1 : index +// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]] +func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %sz = arith.constant -1 : index + %output = memref.reinterpret_cast %arg0 to + offset: [%c0], sizes: [%c1, %sz], strides: [%sz, %c1] + : memref<2x3xf32> to memref> + return %output : memref> +} + +// ----- + +// Check that reinterpret_cast with a negative constant offset is not folded. +// Folding would create an op with a static negative offset, which violates the +// ViewLikeInterface constraint that offsets must be non-negative. +// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset +// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>) +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[NEG:.*]] = arith.constant -1 : index +// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]] +func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref> { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %neg = arith.constant -1 : index + %output = memref.reinterpret_cast %arg0 to + offset: [%neg], sizes: [%c1, %c2], strides: [%c2, %c1] + : memref<2x3xf32> to memref> + return %output : memref> +} + +// ----- + +// Check that reinterpret_cast with a negative constant stride IS folded. +// Negative strides are valid in MemRef layouts (e.g. reverse iteration), +// and the ViewLikeInterface places no non-negativity constraint on strides. +// CHECK-LABEL: func @reinterpret_cast_fold_negative_stride +// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>) +// CHECK-NOT: arith.constant +// CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, 2], strides: [-1, 1] +// CHECK: memref.cast %[[RC]] +func.func @reinterpret_cast_fold_negative_stride(%arg0: memref<2x3xf32>) -> memref> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %neg = arith.constant -1 : index + %output = memref.reinterpret_cast %arg0 to + offset: [%c0], sizes: [%c1, %c2], strides: [%neg, %c1] + : memref<2x3xf32> to memref> + return %output : memref> +} + +// ----- + func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>, %arg1 : index) -> memref> { %c0 = arith.constant 0 : index