Skip to content

Commit

Permalink
[MLIR][MemRef] Only allow fold of cast for the pointer operand, not t…
Browse files Browse the repository at this point in the history
…he value

Currently canonicalizations of a store and a cast try to fold all casts into the store.

In the case where the operand being stored is itself a cast, this is illegal as the type of the value being stored
will change. This PR fixes this by not checking the value for folding with a cast.

Depends on https://reviews.llvm.org/D103828

Differential Revision: https://reviews.llvm.org/D103829
  • Loading branch information
wsmoses committed Jun 8, 2021
1 parent 49454eb commit 965ad79
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Expand Up @@ -942,11 +942,12 @@ void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto cast = operand.get().getDefiningOp<memref::CastOp>();
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
if (cast && operand.get() != ignore &&
!cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
}
Expand Down Expand Up @@ -2270,7 +2271,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
return foldMemRefCast(*this);
return foldMemRefCast(*this, getValueToStore());
}

//===----------------------------------------------------------------------===//
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Expand Up @@ -73,11 +73,12 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto cast = operand.get().getDefiningOp<CastOp>();
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
if (cast && operand.get() != inner &&
!cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
}
Expand Down Expand Up @@ -1425,7 +1426,7 @@ static LogicalResult verify(StoreOp op) {
LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
return foldMemRefCast(*this);
return foldMemRefCast(*this, getValueToStore());
}

//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Expand Up @@ -924,3 +924,15 @@ func @compose_into_affine_vector_load_vector_store(%A : memref<1024xf32>, %u : i
}
return
}

// -----

// CHECK-LABEL: func @no_fold_of_store
// CHECK: %[[cst:.+]] = memref.cast %arg
// CHECK: affine.store %[[cst]]
func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
%0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
affine.store %0, %holder[] : memref<memref<?xi8>>
return
}

10 changes: 10 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
Expand Up @@ -206,4 +206,14 @@ func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
return %1 : index
}

// -----

// CHECK-LABEL: func @no_fold_of_store
// CHECK: %[[cst:.+]] = memref.cast %arg
// CHECK: memref.store %[[cst]]
func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
%0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
memref.store %0, %holder[] : memref<memref<?xi8>>
return
}

0 comments on commit 965ad79

Please sign in to comment.