From d7f3bd23f7e4c18d0c5598deb950d76ed002b73f Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Fri, 8 Mar 2024 15:53:57 -0500 Subject: [PATCH] [Task] : Add comments + enhance check for index in parent block of reshape. --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 25 ++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index f69a10334050b..3594b9669e3c6 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1083,17 +1083,34 @@ struct DimOfMemRefReshape : public OpRewritePattern { return rewriter.notifyMatchFailure( dim, "Dim op is not defined by a reshape op."); + // dim of a memref reshape can be folded if dim.getIndex() dominates the + // reshape. Instead of using `DominanceInfo` (which is usually costly) we + // cheaply check that either of the following conditions hold: + // 1. dim.getIndex() is defined in the same block as reshape but before + // reshape. + // 2. dim.getIndex() is defined in a parent block of + // reshape. + + // Check condition 1 if (dim.getIndex().getParentBlock() == reshape->getBlock()) { if (auto *definingOp = dim.getIndex().getDefiningOp()) { - if (reshape->isBeforeInBlock(definingOp)) + if (reshape->isBeforeInBlock(definingOp)) { return rewriter.notifyMatchFailure( dim, "dim.getIndex is not defined before reshape in the same block."); - } // else dim.getIndex is a block argument to reshape->getBlock - } else if (!dim.getIndex().getParentRegion()->isProperAncestor( - reshape->getParentRegion())) + } + } // else dim.getIndex is a block argument to reshape->getBlock and + // dominates reshape + } // Check condition 2 + else if (dim->getBlock() != reshape->getBlock() && + !dim.getIndex().getParentRegion()->isProperAncestor( + reshape->getParentRegion())) { + // If dim and reshape are in the same block but dim.getIndex() isn't, we + // already know dim.getIndex() dominates reshape without calling + // `isProperAncestor` return rewriter.notifyMatchFailure( dim, "dim.getIndex does not dominate reshape."); + } // Place the load directly after the reshape to ensure that the shape memref // was not mutated.