From 965ad79ea7d0b98f905a27785a6fd0091b904218 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 7 Jun 2021 13:44:07 -0400 Subject: [PATCH] [MLIR][MemRef] Only allow fold of cast for the pointer operand, not the value Currently canonicalizations of a store and a cast try to fold all casts into the store. In the case where the operand being stored is itself a cast, this is illegal as the type of the value being stored will change. This PR fixes this by not checking the value for folding with a cast. Depends on https://reviews.llvm.org/D103828 Differential Revision: https://reviews.llvm.org/D103829 --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 7 ++++--- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 7 ++++--- mlir/test/Dialect/Affine/canonicalize.mlir | 12 ++++++++++++ mlir/test/Dialect/MemRef/canonicalize.mlir | 10 ++++++++++ 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index ef990b70f3575..480b53811483f 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -942,11 +942,12 @@ void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results, /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref.cast /// into the root operation directly. -static LogicalResult foldMemRefCast(Operation *op) { +static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); - if (cast && !cast.getOperand().getType().isa()) { + if (cast && operand.get() != ignore && + !cast.getOperand().getType().isa()) { operand.set(cast.getOperand()); folded = true; } @@ -2270,7 +2271,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult AffineStoreOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// store(memrefcast) -> store - return foldMemRefCast(*this); + return foldMemRefCast(*this, getValueToStore()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index a4ab6c1d0859f..f20234bd1d686 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -73,11 +73,12 @@ static void dispatchIndexOpFoldResults(ArrayRef ofrs, /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref.cast /// into the root operation directly. -static LogicalResult foldMemRefCast(Operation *op) { +static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); - if (cast && !cast.getOperand().getType().isa()) { + if (cast && operand.get() != inner && + !cast.getOperand().getType().isa()) { operand.set(cast.getOperand()); folded = true; } @@ -1425,7 +1426,7 @@ static LogicalResult verify(StoreOp op) { LogicalResult StoreOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// store(memrefcast) -> store - return foldMemRefCast(*this); + return foldMemRefCast(*this, getValueToStore()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index 0a47285e18c49..3d6bd57c27ffc 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -924,3 +924,15 @@ func @compose_into_affine_vector_load_vector_store(%A : memref<1024xf32>, %u : i } return } + +// ----- + +// CHECK-LABEL: func @no_fold_of_store +// CHECK: %[[cst:.+]] = memref.cast %arg +// CHECK: affine.store %[[cst]] +func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref>) { + %0 = memref.cast %arg : memref<32xi8> to memref + affine.store %0, %holder[] : memref> + return +} + diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 354be2237ec30..140cd43ede147 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -206,4 +206,14 @@ func @dim_of_sized_view(%arg : memref, %size: index) -> index { return %1 : index } +// ----- + +// CHECK-LABEL: func @no_fold_of_store +// CHECK: %[[cst:.+]] = memref.cast %arg +// CHECK: memref.store %[[cst]] +func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref>) { + %0 = memref.cast %arg : memref<32xi8> to memref + memref.store %0, %holder[] : memref> + return +}