From f2b6302fb249d111f806b366ec80162ed8277243 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Tue, 21 Jan 2025 19:59:14 +0800 Subject: [PATCH 1/2] [mlir][transforms] Process RegionBranchOp with empty region This PR adds process for RegionBranchOp with empty region, such as `scf.if`. --- mlir/lib/Transforms/RemoveDeadValues.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 3e7a0cca31c77..c20c54551cdf8 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -375,6 +375,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`. auto markLiveArgs = [&](DenseMap &liveArgs) { for (Region ®ion : regionBranchOp->getRegions()) { + if (region.empty()) + continue; SmallVector arguments(region.front().getArguments()); BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la); liveArgs[®ion] = regionLiveArgs; @@ -420,6 +422,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, auto markNonForwardedReturnValues = [&](DenseMap &nonForwardedRets) { for (Region ®ion : regionBranchOp->getRegions()) { + if (region.empty()) + continue; Operation *terminator = region.front().getTerminator(); nonForwardedRets[terminator] = BitVector(terminator->getNumOperands(), true); @@ -499,6 +503,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // Recompute `resultsToKeep` and `argsToKeep` based on // `terminatorOperandsToKeep`. for (Region ®ion : regionBranchOp->getRegions()) { + if (region.empty()) + continue; Operation *terminator = region.front().getTerminator(); for (const RegionSuccessor &successor : getSuccessors(®ion)) { Region *successorRegion = successor.getSuccessor(); @@ -547,6 +553,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // Update the terminator operands that need to be kept. for (Region ®ion : regionBranchOp->getRegions()) { + if (region.empty()) + continue; updateOperandsOrTerminatorOperandsToKeep( terminatorOperandsToKeep[region.back().getTerminator()], resultsToKeep, argsToKeep, ®ion); @@ -611,8 +619,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // Do (2.a) and (2.b). for (Region ®ion : regionBranchOp->getRegions()) { - assert(!region.empty() && "expected a non-empty region in an op " - "implementing `RegionBranchOpInterface`"); + if (region.empty()) + continue; BitVector argsToRemove = argsToKeep[®ion].flip(); cl.blocks.push_back({®ion.front(), argsToRemove}); collectNonLiveValues(nonLiveSet, region.front().getArguments(), @@ -621,6 +629,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // Do (2.c). for (Region ®ion : regionBranchOp->getRegions()) { + if (region.empty()) + continue; Operation *terminator = region.front().getTerminator(); cl.operands.push_back( {terminator, terminatorOperandsToKeep[terminator].flip()}); From 31aecdac1c36b0d88b04f975e883f734e8030454 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Wed, 22 Jan 2025 14:36:37 +0800 Subject: [PATCH 2/2] add test --- mlir/test/Transforms/remove-dead-values.mlir | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index fe7bcbc7c490b..e549926b90456 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -408,6 +408,22 @@ func.func @main(%arg3 : i32, %arg4 : i1) { // ----- +// The scf.if operation represents an if-then-else construct for conditionally +// executing two regions of code. The 'the' region has exactly 1 block, and +// the 'else' region may have 0 or 1 block. This case is to ensure 'else' region +// with 0 block not crash. + +// CHECK-LABEL: func.func @clean_region_branch_op_with_empty_region +func.func @clean_region_branch_op_with_empty_region(%arg0: i1, %arg1: memref) { + %cst = arith.constant 1.000000e+00 : f32 + scf.if %arg0 { + memref.store %cst, %arg1[] : memref + } + return +} + +// ----- + #map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> func.func @kernel(%arg0: memref<18xf32>) { %c1 = arith.constant 1 : index