diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h index b7abcdea10a2a..c4f2cf2413165 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -50,8 +50,9 @@ namespace memref { /// This is a common utility used for patterns of the form /// "someop(memref.cast) -> someop". It folds the source of any memref.cast -/// into the root operation directly. -LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr); +/// into the root operation directly. Operands in `ignoredOperands` are excluded +/// from folding. +LogicalResult foldMemRefCast(Operation *op, ValueRange ignoredOperands = {}); /// Return an unranked/ranked tensor type for the given unranked/ranked memref /// type. diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1035d7cb46e6e..6b82a550668b2 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -41,12 +41,14 @@ Operation *MemRefDialect::materializeConstant(OpBuilder &builder, /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref.cast -/// into the root operation directly. -LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { +/// into the root operation directly. Operands in `ignoredOperands` are excluded +/// from folding. +LogicalResult mlir::memref::foldMemRefCast(Operation *op, + ValueRange ignoredOperands) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); - if (cast && operand.get() != inner && + if (cast && !llvm::is_contained(ignoredOperands, operand.get()) && !llvm::isa(cast.getOperand().getType())) { operand.set(cast.getOperand()); folded = true;